In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from torch.autograd import Variable
from torch.utils.data import DataLoader
from data_prep.dataset import SkeletonDataset, SkeletonDatasetFromDirectory

import st_gcn_parser
import pandas as pd
import argparse
import os
import sys
import random
import time
import json


class Processor:
    """ST-GCN processing wrapper for training and testing the model.

    Methods:
        train()
            Trains the model, given user-defined training parameters.

        test()
            Performs only the forward pass for inference.

    TODO:
        ``1.`` Provide useful prediction statistics (e.g. IoU, jitter, etc.).
    """

    def __init__(
        self,
        model,
        num_classes,
        dataloader,
        device):
        """
        Args:
            model : ``torch.nn.Module``
                Configured PyTorch model.
            
            num_classes : ``int``
                Number of action classification classes.

            dataloader : ``torch.utils.data.DataLoader``
                Data handle to account for its class imbalance in the CE loss.
        """

        classes = torch.tensor(range(num_classes), dtype=torch.float32)
        class_dist = torch.zeros(num_classes, dtype=torch.float32)

        for _, labels in dataloader:
            class_dist += torch.sum(
                (labels[:,:,None].to(torch.float32) == classes[None].expand(labels.shape[1],-1)).to(torch.float32),
                dim=(0,1))

        self.model = model
        self.ce = nn.CrossEntropyLoss(weight=(1-class_dist/torch.sum(class_dist)).to(device=device), reduction='mean')
        self.mse = nn.MSELoss(reduction='none')
        self.num_classes = num_classes


    def update_lr_(self, learning_rate, learning_rate_decay, epoch):
        """Decays learning rate monotonically by the provided factor."""
        
        rate = learning_rate * pow(learning_rate_decay, epoch)
        for g in self.optimizer.param_groups:
            g['lr'] = rate


    def forward_(
        self,
        captures,
        labels,
        device,
        **kwargs):
        """Does the forward pass on the model.
        
        If `dataset_type` is `'dir'`, processes 1 trial at a time, chops each sequence 
        into equal segments that are split across available executors (GPUs) for parallel computation.
        
        If `model` is `'original'` and `latency` is `True`, applies the original classification model
        on non-overlapping windows of size `receptive_field` over the input stream, producing outputs at a 
        reduced temporal resolution inversely proportional to the size of the window. Trades prediction
        resolution for compute (does not compute redundant values for input frames otherwise overlapped by 
        multiple windows).
        """

        # move both data to the compute device
        # (captures is a batch of full-length captures, label is a batch of ground truths)
        captures, labels = captures.to(device), labels.to(device)

        N, _, L, _ = captures.size()

        # Splits trial into overlapping subsequences of samples
        if kwargs['dataset_type'] == 'dir': 
            if kwargs['model'] == 'original':
                # zero pad the input across time from start by the receptive field size
                captures = F.pad(captures, (0, 0, kwargs['receptive_field']-1, 0))
                stride = kwargs['receptive_field'] if kwargs['latency'] else 1
                captures = captures.unfold(2, kwargs['receptive_field'], stride)
                labels = labels[:, ::stride]
            else:
                # Size to divide the trial into to construct a data parallel batch
                # TODO: adjust if kernel is different in multi-stage ST-GCN
                P = kwargs['segment']-(kwargs['kernel'][0]-1)-(L-kwargs['segment'])%(kwargs['segment']-(kwargs['kernel'][0]-1))
                # Pad the end of the sequence to use all of the available readings (masks actual outputs later)
                # TODO: if captures is perfectly unfolded without padding, below call will create a slice of all 0's. Put a conditional to prevent that.
                captures = F.pad(captures, (0, 0, 0, P))
                captures = captures.unfold(2, kwargs['segment'], kwargs['segment']-(kwargs['kernel'][0]-1))
            
            N, C, N_new, V, T_new = captures.size()
            # (N,C,N',V,T') -> batches of unfolded slices
            captures = captures.permute(0, 2, 1, 4, 3).contiguous()
            captures = captures.view(N * N_new, C, T_new, V)
            # (N'',C,T',V)

        # make predictions and compute the loss
        # forward pass the minibatch through the model for the corresponding subject
        # the input tensor has shape (N, V, C, L): N-batch, V-nodes, C-channels, L-length
        # the output tensor has shape (N, C', L)
        predictions = self.model(Variable(captures, requires_grad=True))

        if kwargs['dataset_type'] == 'dir':
            C_new = predictions.size(1)
            if kwargs['model'] == 'original':
                # arrange tensor back into a time series
                predictions = predictions.view(N, N_new, C_new)
                predictions = predictions.permute(0, 2, 1)
            else:
                # clone the tensor from the unfolded view to avoid overwriting underlying data that is viewed in multiple slices
                predictions = torch.clone(predictions)
                # clear the overlapping Gamma-1 predictions at the start of each segment (except the very first segment), since 
                # overlapped regions are added when folding the tensor
                predictions[1:,:,:kwargs['kernel'][0]-1] = 0
                # shuffle data around for the correct contiguous access by the fold()
                predictions = predictions[None].permute(0, 2, 3, 1).contiguous()
                predictions = predictions.view(N, C_new * kwargs['segment'], -1)
                # fold segments of the original trial computed in parallel on multiple executors back into original length sequence
                # and drop the end padding used to fill tensor to equal row-column size
                predictions = F.fold(
                    predictions, 
                    output_size=(1, L+P), 
                    kernel_size=(1, kwargs['segment']), 
                    stride=(1, kwargs['segment']-(kwargs['kernel'][0]-1)))[:,:,0,:L]

        # cross-entropy expects output as class indices (N, C, K), with labels (N, K): 
        # N-batch (flattened multi-skeleton minibatch), C-class, K-extra dimension (capture length)
        # CE + MSE loss metric tuning is taken from @BenjaminFiltjens's MS-GCN:
        # CE guides model toward absolute correctness on single frame predictions,
        # MSE component punishes large variations in class probabilities between consecutive samples
        ce = self.ce(predictions, labels)
        # In the reduced temporal resolution setting of the original model, MSE loss is expected to be large the higher
        # the receptive field since after that many frames a human could start performing a drastically diferent action
        mse = 0.15 * torch.mean(
            torch.clamp(
                self.mse(
                    F.log_softmax(predictions[:,:,1:], dim=1), 
                    F.log_softmax(predictions.detach()[:,:,:-1], dim=1)),
                min=0,
                max=16))

        # calculate the predictions statistics
        # this only sums the number of top-1 correctly predicted frames, but doesn't look at prediction jitter
        _, top5_predicted = torch.topk(predictions, k=5, dim=1)
        top1_predicted = top5_predicted[:,0,:]

        top1_cor = torch.sum(top1_predicted == labels).data.item()
        top5_cor = torch.sum(top5_predicted == labels[:,None,:]).data.item()
        tot = labels.numel()

        return top1_predicted, top5_predicted, top1_cor, top5_cor, tot, ce, mse


    def validate_(
        self,
        dataloader,
        device,
        **kwargs):
        """Does a forward pass without recording gradients. 

        Shared between train and test scripts: train invokes it after each epoch trained,
        test invokes it once for inference only.
        """

        # do not record gradients
        with torch.no_grad():    
            top1_correct = 0
            top5_correct = 0
            total = 0

            test_start_time = time.time()

            confusion_matrix = torch.zeros(self.num_classes, self.num_classes, device=device)
            total_per_class = torch.zeros(self.num_classes, 1, device=device)
            
            ce_epoch_loss_val = 0
            mse_epoch_loss_val = 0

            # sweep through the training dataset in minibatches
            for captures, labels in dataloader:
                top1_predicted, _, top1_cor, top5_cor, tot, ce, mse = self.forward_(captures, labels, device, **kwargs)

                top1_correct += top1_cor
                top5_correct += top5_cor
                total += tot

                stride = kwargs['receptive_field'] if kwargs['latency'] else 1
                labels = labels[:, ::stride]
                N, L = labels.size()
                
                # epoch loss has to multiply by minibatch size to get total non-averaged loss, 
                # which will then be averaged across the entire dataset size, since
                # loss for dataset with equal-length trials averages the CE and MSE losses for each minibatch
                # (used for statistics)
                ce_epoch_loss_val += (ce*N).data.item()
                mse_epoch_loss_val += (mse*N).data.item()

                # delete unnecessary computational graph references to clear space
                del ce, mse

                # collect the correct predictions for each class and total per that class
                # for batch_el in range(N*M):
                for batch_el in range(N):
                    top1_predicted_ohe = torch.zeros(L, self.num_classes, device=device)
                    top1_predicted_ohe[range(L), top1_predicted[batch_el]] = 1
                    confusion_matrix[labels[batch_el, 0]] += top1_predicted_ohe.sum(dim=0)
                    total_per_class[labels[batch_el, 0]] += L

            test_end_time = time.time()

            # normalize each row of the confusion matrix to obtain class probabilities
            confusion_matrix = torch.div(confusion_matrix, total_per_class)

            top1_acc = top1_correct / total
            top5_acc = top5_correct / total
            duration = test_end_time - test_start_time

        return top1_acc, top5_acc, duration, confusion_matrix, ce_epoch_loss_val, mse_epoch_loss_val


    def train(
        self, 
        save_dir, 
        train_dataloader,
        val_dataloader,
        device,
        epochs,
        checkpoints,
        checkpoint,
        learning_rate,
        learning_rate_decay,
        **kwargs):
        """Trains the model, given user-defined training parameters."""

        # move the model to the compute device(s) if available (CPU, GPU, TPU, etc.)
        if torch.cuda.device_count() > 1:
            print("Using", torch.cuda.device_count(), "allocated GPUs", flush=True, file=kwargs['log'][0])
            self.model = nn.DataParallel(self.model)
        self.model.to(device)

        if checkpoint:
            state = torch.load(checkpoint, map_location=device)
            range_epochs = range(state['epoch']+1, epochs)
        else:
            range_epochs = range(epochs)

        # setup the optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        # load the checkpoint if not training from scratch
        if checkpoint:
            self.optimizer.load_state_dict(state['optimizer_state_dict'])

        # variables for email updates
        epoch_list = []

        top1_acc_train_list = []
        top1_acc_val_list = []
        top5_acc_train_list = []
        top5_acc_val_list = []
        duration_train_list = []
        duration_val_list = []

        ce_loss_train_list = []
        mse_loss_train_list = []
        epoch_loss_train_list = []

        ce_loss_val_list = []
        mse_loss_val_list = []
        epoch_loss_val_list = []

        # train the model for num_epochs
        # (dataloader is automatically shuffled after each epoch)
        for epoch in range_epochs:
            # set layers to training mode if behavior of any differs between train and prediction
            # (prepares Dropout and BatchNormalization layers to disable and to learn parameters, respectively)
            self.model.train()

            ce_epoch_loss_train = 0
            mse_epoch_loss_train = 0

            ce_loss = 0
            mse_loss = 0

            top1_correct = 0
            top5_correct = 0
            total = 0

            # decay learning rate every 10 epochs [ref: Yan 2018]
            if (epoch % 10 == 0):
                self.update_lr_(learning_rate, learning_rate_decay, epoch//10)

            epoch_start_time = time.time()

            self.optimizer.zero_grad()

            # sweep through the training dataset in minibatches
            for i, (captures, labels) in enumerate(train_dataloader):
                N, _, _, _ = captures.size()

                _, _, top1_cor, top5_cor, tot, ce, mse = self.forward_(captures, labels, device, **kwargs)

                top1_correct += top1_cor
                top5_correct += top5_cor
                total += tot

                # epoch loss has to multiply by minibatch size to get total non-averaged loss, 
                # which will then be averaged across the entire dataset size, since
                # loss for dataset with equal-length trials averages the CE and MSE losses for each minibatch
                # (used for statistics)
                ce_epoch_loss_train += (ce*N).data.item()
                mse_epoch_loss_train += (mse*N).data.item()

                # accumulate losses (used for backpropagation)
                ce_loss += ce
                mse_loss += mse
                # delete unnecessary computational graph references to clear space
                del ce, mse

                # zero the gradient buffers after every batch
                # if dataset is a tensor with equal length trials, always enters
                # if dataset is a set of different length trials, enters every ``batch_size`` iteration
                if ((kwargs['dataset_type'] == 'dir' and
                        ((i + 1) % kwargs['batch_size'] == 0 or 
                        (i + 1) == len(train_dataloader))) or
                    (kwargs['dataset_type'] == 'file')):

                    # loss is already a mean across minibatch for tensor of equally long trials, but
                    # not for different-length trials -> needs averaging
                    loss = ce_loss + mse_loss
                    if (kwargs['dataset_type'] == 'dir' and (i + 1) % kwargs['batch_size'] == 0):
                        # if the minibatch is the same size as requested (first till one before last minibatch)
                        loss /= kwargs['batch_size']
                    elif (kwargs['dataset_type'] == 'dir' and (i + 1) == len(train_dataloader)):
                        # if the minibatch is smaller than requested (last minibatch)
                        loss /= ((i + 1) % kwargs['batch_size'])

                    # backward pass to compute the gradients
                    loss.backward()

                    # update parameters based on the computed gradients
                    self.optimizer.step()

                    # clear the loss
                    ce_loss = 0
                    mse_loss = 0
                    del loss

                    # clear the gradients
                    self.optimizer.zero_grad()

            epoch_end_time = time.time()
            duration_train = epoch_end_time - epoch_start_time
            top1_acc_train = top1_correct / total
            top5_acc_train = top5_correct / total
            
            # checkpoint the model during training at specified epochs
            if epoch in checkpoints:
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "loss": (ce_epoch_loss_train + mse_epoch_loss_train) / len(train_dataloader),
                    }, "{0}/epoch-{1}.pt".format(save_dir, epoch))
            
            # set layers to inference mode if behavior differs between train and prediction
            # (prepares Dropout and BatchNormalization layers to enable and to freeze parameters, respectively)
            self.model.eval()

            # test the model on the validation set
            # will complain on CUDA devices that input gradients are none: irrelevant because it is a side effect of
            # the shared `forward_()` routine for both tasks, where the model is set to `train()` or `eval()` in the
            # corresponding caller function
            top1_acc_val, top5_acc_val, duration_val, confusion_matrix, ce_epoch_loss_val, mse_epoch_loss_val = self.validate_(
                dataloader=val_dataloader,
                device=device,
                **kwargs)

            # record all stats of interest for logging/notification
            epoch_list.insert(0, epoch)

            ce_loss_train_list.insert(0, ce_epoch_loss_train / len(train_dataloader))
            mse_loss_train_list.insert(0, mse_epoch_loss_train / len(train_dataloader))
            epoch_loss_train_list.insert(0, (ce_epoch_loss_train + mse_epoch_loss_train) / len(train_dataloader))

            ce_loss_val_list.insert(0, ce_epoch_loss_val / len(val_dataloader))
            mse_loss_val_list.insert(0, mse_epoch_loss_val / len(val_dataloader))
            epoch_loss_val_list.insert(0, (ce_epoch_loss_val + mse_epoch_loss_val) / len(val_dataloader))

            top1_acc_train_list.insert(0, top1_acc_train)
            top1_acc_val_list.insert(0, top1_acc_val)
            top5_acc_train_list.insert(0, top5_acc_train)
            top5_acc_val_list.insert(0, top5_acc_val)            
            duration_train_list.insert(0, duration_train)
            duration_val_list.insert(0, duration_val)

            # save confusion matrix as a CSV file
            pd.DataFrame(confusion_matrix.cpu().numpy()).to_csv('{0}/confusion_matrix_epoch-{1}.csv'.format(save_dir, epoch))

            # log and send notifications
            print(
                "[epoch {0}]: epoch loss = {1}, top1_acc_train = {2}, top5_acc_train = {3}, top1_acc_val = {4}, top5_acc_val = {5}"
                .format(
                    epoch, 
                    (ce_epoch_loss_train + mse_epoch_loss_train) / len(train_dataloader),
                    top1_acc_train,
                    top5_acc_train,
                    top1_acc_val,
                    top5_acc_val),
                flush=True,
                file=kwargs['log'][0])
            
            if kwargs['verbose'] > 0:
                print(
                    "[epoch {0}]: train_time = {1}, val_time = {2}"
                    .format(
                        epoch,
                        duration_train,
                        duration_val),
                    flush=True,
                    file=kwargs['log'][0])
            
            # format a stats table (in newest to oldest order) and send it as email update
            if kwargs['verbose'] > 1:
                os.system(
                    'header="\n %-6s %5s %11s %11s %9s %9s %11s %9s\n";'
                    'format=" %-03d %4.6f %1.4f %1.4f %1.4f %1.4f %5.6f %5.6f\n";'
                    'printf "$header" "EPOCH" "LOSS" "TOP1_TRAIN" "TOP5_TRAIN" "TOP1_VAL" "TOP5_VAL" "TIME_TRAIN" "TIME_VAL" > $PBS_O_WORKDIR/mail_draft_{1}.txt;'
                    'printf "$format" {0} >> $PBS_O_WORKDIR/mail_draft_{1}.txt;'
                    'cat $PBS_O_WORKDIR/mail_draft_{1}.txt | mail -s "[{1}]: $PBS_JOBNAME status update" {2}'
                    .format(
                        ' '.join([
                            ' '.join([str(e) for e in t]) for t 
                            in zip(
                                epoch_list,
                                epoch_loss_train_list,
                                top1_acc_train_list,
                                top5_acc_train_list,
                                top1_acc_val_list,
                                top5_acc_val_list,
                                duration_train_list,
                                duration_val_list)]),
                        os.getenv('PBS_JOBID').split('.')[0],
                        kwargs['email']))

            # save (update) train-validation curve as a CSV file after each epoch
            pd.DataFrame(
                data={
                    'top1_train': top1_acc_train_list,
                    'top1_val': top1_acc_val_list,
                    'top5_train': top5_acc_train_list,
                    'top5_val': top5_acc_val_list
                }).to_csv('{0}/accuracy-curve.csv'.format(save_dir))

            # save (update) loss curve as a CSV file after each epoch
            pd.DataFrame(
                data={
                    'ce_train': ce_loss_train_list,
                    'mse_train': mse_loss_train_list,
                    'ce_val': ce_loss_val_list,
                    'mse_val': mse_loss_val_list,
                }).to_csv('{0}/train-validation-curve.csv'.format(save_dir))

        # save the final model
        torch.save({
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "loss": (ce_epoch_loss_train + mse_epoch_loss_train) / len(train_dataloader),
            }, "{0}/final.pt".format(save_dir))

        return


    def test(
        self,
        save_dir,
        dataloader,
        device,
        **kwargs):
        """Performs only the forward pass for inference.
        """
        
        # set layers to inference mode if behavior differs between train and prediction
        # (prepares Dropout and BatchNormalization layers to enable and to freeze parameters, respectively)
        self.model.eval()
        
        # move the model to the compute device(s) if available (CPU, GPU, TPU, etc.)
        if torch.cuda.device_count() > 1:
            print("Using", torch.cuda.device_count(), "allocated GPUs", flush=True, file=kwargs['log'][0])
            self.model = nn.DataParallel(self.model)
        self.model.to(device)

        # test the model on the validation set
        top1_acc_val, top5_acc_val, duration_val, confusion_matrix, ce_epoch_loss_val, mse_epoch_loss_val = self.validate_(
            dataloader=dataloader,
            device=device,
            **kwargs)
        
        # save confusion matrix as a CSV file
        pd.DataFrame(confusion_matrix.cpu().numpy()).to_csv('{0}/confusion_matrix.csv'.format(save_dir))

        # log and send notifications
        print(
            "[test]: top1_acc = {0}, top5_acc = {1}"
            .format( 
                top1_acc_val,
                top5_acc_val),
            flush=True,
            file=kwargs['log'][0])
        
        if kwargs['verbose'] > 0:
            print(
                "[test]: time = {1}"
                .format(duration_val),
                flush=True,
                file=kwargs['log'][0])

        # format a stats table (in newest to oldest order) and send it as email update
        if kwargs['verbose'] > 1:
            os.system(
                'header="\n %-5s %5s %5s\n";'
                'format=" %-1.4f %1.4f %5.6f\n";'
                'printf "$header" "TOP1" "TOP5" "TIME" > $PBS_O_WORKDIR/mail_draft_{1}.txt;'
                'printf "$format" {0} >> $PBS_O_WORKDIR/mail_draft_{1}.txt;'
                'cat $PBS_O_WORKDIR/mail_draft_{1}.txt | mail -s "[{1}]: $PBS_JOBNAME status update" {2}'
                .format(
                    ' '.join([
                        ' '.join([str(e) for e in t]) for t 
                        in zip(
                            top1_acc_val,
                            top5_acc_val,
                            duration_val)]),
                    os.getenv('PBS_JOBID').split('.')[0],
                    kwargs['email']))
        return


def common(args):
    """Performs setup common to any ST-GCN model variant.
    
    Only needs to be invoked once for a given problem (train-test, benchmark, etc.). 
    Corresponds to the parts of the pipeline irrespective of the black-box model used.
    Creates DataLoaders, sets up processing device and random number generator,
    reads action classes file.

    Args:
        args : ``dict``
            Parsed CLI arguments.

    Returns:
        Dictionary of action classes.

        PyTorch device (CPU or GPU).

        Train and validation DataLoaders.
    """
    
    # setting up random number generator for deterministic and meaningful benchmarking
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True

    # preparing datasets for training and validation
    if args.dataset_type == 'file':
        train_data = SkeletonDataset('{0}/train_data.npy'.format(args.data), '{0}/train_label.pkl'.format(args.data))
        val_data = SkeletonDataset('{0}/val_data.npy'.format(args.data), '{0}/val_label.pkl'.format(args.data))
    elif args.dataset_type == 'dir':
        train_data = SkeletonDatasetFromDirectory('{0}/train/features'.format(args.data), '{0}/train/labels'.format(args.data))
        val_data = SkeletonDatasetFromDirectory('{0}/val/features'.format(args.data), '{0}/val/labels'.format(args.data))

    # trials of different length can not be placed in the same tensor when batching, have to manually iterate over them
    batch_size = 1 if args.dataset_type == 'dir' else args.batch_size
    
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
    
    # extract skeleton graph data
    with open(args.graph, 'r') as graph_file:
        graph = json.load(graph_file)

    # extract actions from the label file
    with open(args.actions, 'r') as action_names:
        actions = action_names.read().split('\n')

    # 0th class is always background action
    actions_dict = {0: "background"}
    for i, action in enumerate(actions):
        actions_dict[i+1] = action

    # prepare a directory to store results
    if not os.getenv('PBS_JOBID'):
        with open('.vscode/pbs_jobid.txt', 'r+') as f:
            job_id = f.readline()
            os.environ['PBS_JOBID'] = job_id
            f.seek(0)
            f.write(str(int(job_id)+1))
            f.truncate()
    
    save_dir = "{0}/{1}/run_{2}".format(args.out, args.model, os.getenv('PBS_JOBID').split('.')[0])
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    return graph, actions_dict, device, train_dataloader, val_dataloader, save_dir


def build_model(args):
    """Builds the selected ST-GCN model variant.
    
    Args:
        args : ``dict``
            Parsed CLI arguments.

    Returns:
        PyTorch Model corresponding to the user-defined CLI parameters.
    
    Raises:
        ValueError: 
            If GCN parameter list sizes do not match the number of stages.
    """

    if (len(args.in_ch) != args.stages or
        len(args.out_ch) != args.stages or
        len(args.stride) != args.stages or
        len(args.residual) != args.stages):
        raise ValueError(
            'GCN parameter list sizes do not match the number of stages. '
            'Check your config file.')
    elif (args.model == 'realtime' and args.buffer != 1):
        raise ValueError(
            'Selected the realtime model, but set buffer size to 1. '
            'Check your config file.')
    
    if args.model == 'original':
        model = OriginalStgcn(**vars(args))
    elif args.model == 'adapted':
        model = AdaptedStgcn(**vars(args))
    else:
        # all 3 adapted versions are encapsulated in the same class, training is identical (batch mode),
        # usecase changes applied during inference
        model = Stgcn(**vars(args))
    
    return model


def train(args):
    """Entry point for training functionality of a single selected model.

    Args:
        args : ``dict``
            Parsed CLI arguments.
    """

    # perform common setup around the model's black box
    args.graph, actions, device, train_dataloader, val_dataloader, save_dir = common(args)
    args.num_classes = len(actions)

    # construct the target model using the CLI arguments
    model = build_model(args)
    # load the checkpoint if not trained from scratch
    if args.checkpoint:
        model.load_state_dict(torch.load(args.checkpoint, map_location=device)['model_state_dict'])
    #     model.load_state_dict({
    #         k.split('module.')[1]: v 
    #         for k, v in
    #         torch.load(args.checkpoint, map_location=device)['model_state_dict'].items()})

    # construct a processing wrapper
    trainer = Processor(model, args.num_classes, train_dataloader, device)

    start_time = time.time()

    # last dimension is the number of subjects in the scene (2 for datasets used)
    print("Training started", flush=True, file=args.log[0])
    
    # perform the training
    # (the model is trained on all skeletons in the scene, simultaneously)
    trainer.train(
        save_dir=save_dir,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,    
        **vars(args))
    
    print("Training completed in: {0}".format(time.time() - start_time), flush=True, file=args.log[0])
    
    os.system(
        'mail -s "[{0}]: $PBS_JOBNAME - COMPLETED" {1} <<< ""'
        .format(
            os.getenv('PBS_JOBID').split('.')[0],
            args.email))

    return


def test(args):
    """Entry point for testing functionality of a single pretrained model.

    Args:
        args : ``dict``
            Parsed CLI arguments.
    """

    # perform common setup around the model's black box
    args.graph, actions, device, val_dataloader, _, save_dir = common(args)
    args.num_classes = len(actions)

    # construct the target model using the CLI arguments
    model = build_model(args)
    # model.load_state_dict(torch.load(args.checkpoint, map_location=device)['model_state_dict'])
    # load the checkpoint if not trained from scratch
    if args.checkpoint:
        model.load_state_dict({
            k.split('module.')[1]: v 
            for k, v in
            torch.load(args.checkpoint, map_location=device)['model_state_dict'].items()})

    # construct a processing wrapper
    trainer = Processor(model, args.num_classes, val_dataloader, device)

    start_time = time.time()

    # last dimension is the number of subjects in the scene (2 for datasets used)
    print("Testing started", flush=True, file=args.log[0])
    
    # perform the testing
    trainer.test(save_dir, val_dataloader, device, **vars(args))
    
    print("Testing completed in: {0}".format(time.time() - start_time), flush=True, file=args.log[0])
    
    os.system(
        'mail -s "[{0}]: $PBS_JOBNAME - COMPLETED" {1} <<< ""'
        .format(
            os.getenv('PBS_JOBID').split('.')[0], 
            args.email))

    return


def benchmark(args):
    """Entry point for benchmarking functionality of multiple pretrained models.

    TODO: complete

    Args:
        args : ``dict``
            Parsed CLI arguments.
    """

    # perform common setup around the model's black box
    args.graph, actions, device, _, val_dataloader, save_dir = common(args)
    args.num_classes = len(actions)
    
    # split between the subjects in the captures
    data, _ = next(iter(val_dataloader))
    args.capture_length = data.shape[2]

    # construct the target models using the CLI arguments
    models = []
    for m in args.models:
        model = build_model(m)
        
        model.load_state_dict({
            k.split('module.')[1]: v 
            for k, v in
            torch.load(args.checkpoint, map_location=device)['model_state_dict'].items()})

        models.append(model)

    # construct a processing wrapper
    trainer = Processor(model, args.num_classes)

    start_time = time.time()

    # last dimension is the number of subjects in the scene (2 for datasets used)
    print("Testing started", flush=True, file=args.log[0])
    
    # perform the testing
    trainer.test(save_dir, val_dataloader, device, **vars(args))

    print("Benchmarking completed in: {0}".format(time.time() - start_time), flush=True, file=args.log[0])
    
    os.system(
        'mail -s "[{0}]: $PBS_JOBNAME - COMPLETED" {1} <<< ""'
        .format(
            os.getenv('PBS_JOBID').split('.')[0],
            args.email))

    return

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.utils.tgcn import ConvTemporalGraphical
from models.utils.graph import Graph
from torch.utils.checkpoint import checkpoint


class Stgcn(nn.Module):
    """Spatial temporal graph convolutional network of Yan, et al. (2018), adapted for realtime.
    (https://arxiv.org/abs/1801.07455).

    Implements both, realtime and buffered realtime logic in the same source. 
    At runtime, the model looks at the L-dimension of the tensor to make the predictions.
    This enforces computations are numerically correct if the frame buffer is not completely full
    (e.g. the last minibatch of frames from a recording if ``capture_length % buffer != 0``).
    
    All arguments are positional to enforce separation of concern and pass the responsibility for
    model configuration up in the chain to the envoking program (config file).

    TODO:
        ``1.`` add logic for variation in FIFO latency.

    Shape:
        - Input[0]:    :math:`(N, C_{in}, L, V)`.
        - Output[0]:   :math:`(N, C_{out}, L)`. 
        
        where
            :math:`N` is a batch size.

            :math:`C_{in}` is the number of input channels (features).

            :math:`C_{out}` is the number of classification classes.

            :math:`L` is the number of frames (capture length).

            :math:`V` is the number of graph nodes.
    """

    def __init__(
        self,
        **kwargs) -> None:
        """
        Kwargs:
            in_feat : ``int`` 
                Number of input sample channels/features.
            
            num_classes : ``int``
                Number of output classification classes.
            
            kernel : ``list[int]``
                Temporal kernel size Gamma.
            
            importance : ``bool``
                If ``True``, adds a learnable importance weighting to the edges of the graph.
            
            latency : ``bool``
                If ``True``, residual connection adds 
                ``x_{t}`` frame to ``y_{t}`` (which adds ``ceil(kernel_size/2)`` latency), 
                adds ``x_{t}`` frame to ``y_{t-ceil(kernel_size/2)}`` otherwise.
            
            layers : ``list[int]``
                Array of number of ST-GCN layers, oner per stage.

            in_ch : ``list[list[int]]``
                2D array of input channel numbers, one per stage per ST-GCN layer.

            out_ch : ``list[list[int]]``
                2D array of output channel numbers, one per stage per ST-GCN layer.

            stride : ``list[list[int]]``
                2D array of temporal stride sizes, one per stage per ST-GCN layer.

            residual : ``list[list[int]]``
                2D array of residual connection flags, one per stage per ST-GCN layer.

            dropout : ``list[list[float]]``
                2D array of dropout parameters, one per stage per ST-GCN layer.

            graph : ``dict`` 
                Dictionary with parameters for skeleton Graph.

            strategy : ``str``
                Type of Graph partitioning strategy.
        """

        super().__init__()
        
        # save the config arguments for model conversions
        self.conf = kwargs

        # verify that parameter dimensions match (correct number of layers/parameters per stage)
        for i, layers_in_stage in enumerate(kwargs['layers']):
            assert((len(kwargs['in_ch'][i]) == layers_in_stage) and
                    (len(kwargs['out_ch'][i]) == layers_in_stage) and
                    (len(kwargs['stride'][i]) == layers_in_stage) and
                    (len(kwargs['residual'][i]) == layers_in_stage),
                ("Incorrect number of constructor parameters in the ST-GCN stage ModuleList.\n"
                "Expected for stage {0}: {1}, got: ({2}, {3}, {4}, {5})")
                .format(
                    i, 
                    kwargs['layers'][i], 
                    len(kwargs['in_ch'][i]), 
                    len(kwargs['out_ch'][i]), 
                    len(kwargs['stride'][i]), 
                    len(kwargs['residual'][i])))

        # register the normalized adjacency matrix as a non-learnable saveable parameter 
        self.graph = Graph(strategy=kwargs['strategy'], **kwargs['graph'])
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # input capture normalization
        # (N,C,L,V)
        self.bn_in = nn.BatchNorm1d(kwargs['in_feat'] * A.size(1))
        
        # fcn for feature remapping of input to the network size
        self.fcn_in = nn.Conv2d(in_channels=kwargs['in_feat'], out_channels=kwargs['in_ch'][0][0], kernel_size=1)
        
        # stack of ST-GCN layers
        stack = [[StgcnLayer(
                    num_joints=kwargs['graph']['num_node'],
                    in_channels=kwargs['in_ch'][i][j],
                    out_channels=kwargs['out_ch'][i][j],
                    kernel_size=kwargs['kernel'][i],
                    stride=kwargs['stride'][i][j],
                    num_partitions=self.A.shape[0],
                    residual=not not kwargs['residual'][i][j],
                    dropout=kwargs['dropout'][i][j])
                for j in range(layers_in_stage)] 
                for i, layers_in_stage in enumerate(kwargs['layers'])]
        # flatten into a single sequence of layers after parameters were used to construct
        # (done like that to make config files more readable)
        self.st_gcn = nn.ModuleList([module for sublist in stack for module in sublist])
        
        # global pooling
        # converts (N,C,L,V) -> (N,C,L,1)
        self.avg_pool = nn.AvgPool2d(kernel_size=(1, kwargs['graph']['num_node']))

        # fcn for prediction
        # maps C to num_classes channels: (N,C,L,1) -> (N,F,L,1) 
        self.fcn_out = nn.Conv2d(
            in_channels=kwargs['out_ch'][-1][-1],
            out_channels=kwargs['num_classes'],
            kernel_size=1)

        # learnable edge importance weighting matrices (each layer, separate weighting)
        if kwargs['importance']:
            self.edge_importance = nn.ParameterList(
                [nn.Parameter(
                    torch.ones(
                        kwargs['graph']['num_node'], 
                        kwargs['graph']['num_node'], 
                        requires_grad=True)) 
                for _ in self.st_gcn])
        else:
            self.edge_importance = [1] * len(self.st_gcn)


    def forward(self, x):
        # data normalization
        N, C, T, V = x.size()
        # permutes must copy the tensor over as contiguous because .view() needs a contiguous tensor
        # this incures extra overhead
        x = x.permute(0, 3, 1, 2).contiguous()
        # (N,V,C,T)
        x = x.view(N, V * C, T)
        x = self.bn_in(x)
        x = x.view(N, V, C, T)
        x = x.permute(0, 2, 3, 1)
        # (N,C,T,V)

        # remap the features to the network size
        x = self.fcn_in(x)

        # feed the frame into the ST-GCN block
        for st_gcn, importance in zip(self.st_gcn, self.edge_importance):
            # adjacency matrix is a 3D tensor (size depends on the partition strategy)
            x = checkpoint(st_gcn, x, self.A * importance)

        # pool the output frame for a single feature vector
        x = self.avg_pool(x)

        # remap the feature vector to class predictions
        x = self.fcn_out(x)

        # removes the last dimension (node dimension) of size 1: (N,C,L,1) -> (N,C,L)
        return x.squeeze(-1)


    # def _swap_layers_for_inference(self: nn.Module) -> nn.Module:
        
    #     return


    # def train(self: nn.Module, mode: bool = True) -> nn.Module:
    #     # TODO: 
    #     return super().train(mode)

    
    # def eval(self: nn.Module) -> nn.Module:
    #     super().eval()

    #     # stack of ST-GCN layers
    #     stack = [[RtStgcnLayer(
    #                 num_joints=self.conf['graph']['num_node'],
    #                 fifo_latency=self.conf['latency'],
    #                 in_channels=self.conf['in_ch'][i][j],
    #                 out_channels=self.conf['out_ch'][i][j],
    #                 kernel_size=self.conf['kernel'][i],
    #                 stride=self.conf['stride'][i][j],
    #                 num_partitions=self.A.shape[0],
    #                 residual=not not self.conf['residual'][i][j],
    #                 dropout=self.conf['dropout'][i][j],
    #                 **self.conf)
    #             for j in range(layers_in_stage)] 
    #             for i, layers_in_stage in enumerate(self.conf['layers'])]
    #     # flatten into a single sequence of layers after parameters were used to construct
    #     # (done like that to make config files more readable)
    #     new_st_gcn = nn.ModuleList([module for sublist in stack for module in sublist])

    #     # TODO: copy trained weights over from batch training to the inference layers
    #     # for self.parameters()

    #     return self

    
class RtStgcnLayer(nn.Module):
    """[!Inference only!] Applies a spatial temporal graph convolution over an input graph sequence.
    
    Each layer has a FIFO to store the corresponding Gamma-sized window of graph frames.
    All arguments are positional to enforce separation of concern and pass the responsibility for
    model configuration up in the chain to the envoking program (config file).

    TODO:
        ``1.`` add logic for variation in FIFO latency.

        ``2.`` write more elaborate class description about the design choices and working principle.

    Shape:
        - Input[0]:     :math:`(N, C_{in}, L, V)` - Input graph frame.
        - Input[1]:     :math:`(P, V, V)` - Graph adjacency matrix.
        - Output[0]:    :math:`(N, C_{out}, L, V)` - Output graph frame.

        where
            :math:`N` is the batch size.
            
            :math:`L` is the buffer size (buffered frames number).

            :math:`C_{in}` is the number of input channels (features).

            :math:`C_{out}` is the number of output channels (features).

            :math:`V` is the number of graph nodes.

            :math:`P` is the number of graph partitions.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_joints,
        stride,
        num_partitions,
        dropout,
        residual,
        fifo_latency,
        **kwargs):
        """
        Args:
            in_channels : ``int``
                Number of input sample channels/features.
            
            out_channels : ``int``
                Number of channels produced by the convolution.
            
            kernel_size : ``int``
                Size of the temporal window Gamma.
            
            num_joints : ``int``
                Number of joint nodes in the graph.
            
            stride : ``int``
                Stride of the temporal reduction.
            
            num_partitions : ``int``
                Number of partitions in selected strategy.
                Must correspond to the first dimension of the adjacency tensor.
            
            dropout : ``float``
                Dropout rate of the final output.
            
            residual : ``bool``
                If ``True``, applies a residual connection.
            
            fifo_latency : ``bool``
                If ``True``, residual connection adds ``x_{t}`` frame to ``y_{t}`` (which adds ``ceil(kernel_size/2)`` latency), 
                otherwise adds ``x_{t}`` frame to ``y_{t-ceil(kernel_size/2)}``.
        """
        
        super().__init__()

        # temporal kernel Gamma is symmetric (odd number)
        # assert len(kernel_size) == 1
        assert kernel_size % 2 == 1

        self.num_partitions = num_partitions
        self.num_joints = num_joints
        self.stride = stride

        self.out_channels = out_channels
        self.fifo_size = stride*(kernel_size-1)+1
        
        # convolution of incoming frame 
        # (out_channels is a multiple of the partition number
        # to avoid for-looping over several partitions)
        # partition-wise convolution results are basically stacked across channel-dimension
        self.conv = nn.Conv2d(in_channels, out_channels*num_partitions, kernel_size=1)

        # FIFO for intermediate Gamma graph frames after multiplication with adjacency matrices
        # (N,G,P,C,V) - (N)batch, (G)amma, (P)artition, (C)hannels, (V)ertices
        self.fifo = torch.zeros(kwargs['batch_size'], self.fifo_size, num_partitions, out_channels, num_joints)
        
        # normalization and dropout on main branch
        self.bn_do = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True))

        # residual branch
        if not residual:
            self.residual = lambda _: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels))

        # activation of branch sum
        self.relu = nn.ReLU(inplace=True)


    def forward(self, x, A):
        """
        In case of buffered realtime processing, Conv2D and MMUL are done on the buffered frames,
        which mimics the kernels reuse mechanism that would be followed in hardware at the expense
        of extra memory for storing intermediate results.

        TODO:
            ``1.`` Speed up the for-loop part (buffered realtime setup) by vectorization.
        """
        # residual branch
        res = self.residual(x)
        
        # spatial convolution of incoming frame (node-wise)
        a = self.conv(x)

        # convert to the expected dimension order and add the partition dimension
        # reshape the tensor for multiplication with the adjacency matrix
        # (convolution output contains all partitions, stacked across the channel dimension)
        # split into separate 4D tensors, each corresponding to a separate partition
        b = torch.split(a, self.out_channels, dim=1)
        # concatenate these 4D tensors across the partition dimension
        c = torch.stack(b, -1)
        # change the dimension order for the correct broadcating of the adjacency matrix
        # (N,C,L,V,P) -> (N,L,P,C,V)
        d = c.permute(0,2,4,1,3)
        # single multiplication with the adjacency matrices (spatial selective addition, across partitions)
        e = torch.matmul(d, A)

        # perform temporal accumulation for each of the buffered frames
        # (portability for buffered_realtime setup, for realtime, the buffer is of size 1)
        outputs = []
        for i in range(e.shape[1]):
            # push the frame into the FIFO
            self.fifo = torch.cat((e[:,i:i+1], self.fifo[:,:self.fifo_size-1]), 1)
            
            # slice the tensor according to the temporal stride size
            # (if stride is 1, returns the whole tensor itself)
            f = self.fifo[:,range(0, self.fifo_size, self.stride)]

            # sum temporally and across partitions
            # (C,H)
            g = torch.sum(f, dim=(1,2))
            outputs.append(g)

        # stack frame-wise tensors into the original length L
        # [(N,C,V)] -> (N,C,L,V)
        h = torch.stack(outputs, 2)

        # add the branches (main + residual)
        i = h + res

        return self.relu(i)


class StgcnLayer(nn.Module):
    """[Training] Applies a spatial temporal graph convolution over an input graph sequence.
    
    Processes the entire video capture during training; it is mandatory to retain intermediate values
    for backpropagation (hence no FIFOs allowed in training). Results of training with either layer
    are identical, it is simply a nuissance of autodiff frameworks.
    All arguments are positional to enforce separation of concern and pass the responsibility for
    model configuration up in the chain to the envoking program (config file).

    TODO:
        ``1.`` validate documentation.

    Shape:
        - Input[0]:     :math:`(N, C_{in}, L, V)` - Input graph frame.
        - Input[1]:     :math:`(P, V, V)` - Graph adjacency matrix.
        - Output[0]:    :math:`(N, C_{out}, L, V)` - Output graph frame.

        where
            :math:`N` is the batch size.

            :math:`C_{in}` is the number of input channels (features).

            :math:`C_{out}` is the number of output channels (features).

            :math:`L` is the video capture length.

            :math:`V` is the number of graph nodes.

            :math:`P` is the number of graph partitions.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_joints,
        stride,
        num_partitions,
        dropout,
        residual):
        """
        Args:
            in_channels : ``int``
                Number of input sample channels/features.
            
            out_channels : ``int``
                Number of channels produced by the convolution.
            
            kernel_size : ``int``
                Size of the temporal window Gamma.
            
            num_joints : ``int``
                Number of joint nodes in the graph.
            
            stride : ``int``
                Stride of the temporal reduction.
            
            num_partitions : ``int``
                Number of partitions in selected strategy.
                Must correspond to the first dimension of the adjacency tensor.
            
            dropout : ``float``
                Dropout rate of the final output.
            
            residual : ``bool``
                If ``True``, applies a residual connection.
        """
        
        super().__init__()

        # temporal kernel Gamma is symmetric (odd number)
        # assert len(kernel_size) == 1
        assert kernel_size % 2 == 1

        self.num_partitions = num_partitions
        self.num_joints = num_joints
        self.stride = stride
        self.kernel_size = kernel_size

        self.out_channels = out_channels

        # convolution of incoming frame 
        # (out_channels is a multiple of the partition number
        # to avoid for-looping over several partitions)
        # partition-wise convolution results are basically stacked across channel-dimension
        self.conv = nn.Conv2d(in_channels, out_channels*num_partitions, kernel_size=1, bias=False)
        
        # normalization and dropout on main branch
        self.bn_relu = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

        # residual branch
        if not residual:
            self.residual = lambda _: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels))

        # activation of branch sum
        # if no resnet connection, prevent ReLU from being applied twice
        if not residual:
            self.do = nn.Dropout(dropout)
        else:
            self.do = nn.Sequential(
                nn.ReLU(),
                nn.Dropout(dropout))


    def forward(self, x, A):
        # TODO: replace with unfold -> fold calls
        # lower triangle matrix for temporal accumulation that mimics FIFO behavior
        capture_length = x.size(2)
        device = torch.device("cuda:{0}".format(torch.cuda.current_device()) if torch.cuda.is_available() else "cpu")
        lt_matrix = torch.zeros(capture_length, capture_length, device=device)
        for i in range(self.kernel_size//self.stride):
            lt_matrix += F.pad(
                torch.eye(
                    capture_length - self.stride * i,
                    device=device),
                (i*self.stride,0,0,i*self.stride))
        # must register matrix as a buffer to automatically move to GPU with model.to_device()
        # for PyTorch v1.0.1
        # self.register_buffer('lt_matrix', lt_matrix)

        # residual branch 
        res = self.residual(x) 
         
        # spatial convolution of incoming frame (node-wise) 
        x = self.conv(x) 
 
        # convert to the expected dimension order and add the partition dimension 
        # reshape the tensor for multiplication with the adjacency matrix 
        # (convolution output contains all partitions, stacked across the channel dimension) 
        # split into separate 4D tensors, each corresponding to a separate partition 
        x = torch.split(x, self.out_channels, dim=1) 
        # concatenate these 4D tensors across the partition dimension 
        x = torch.stack(x, -1) 
        # change the dimension order for the correct broadcating of the adjacency matrix 
        # (N,C,L,V,P) -> (N,L,P,C,V) 
        x = x.permute(0,2,4,1,3) 
        # single multiplication with the adjacency matrices (spatial selective addition, across partitions) 
        x = torch.matmul(x, A) 
 
        # sum temporally by multiplying features with the Toeplitz matrix 
        # reorder dimensions for correct broadcasted multiplication (N,L,P,C,V) -> (N,P,C,V,L) 
        x = x.permute(0,2,3,4,1) 
        x = torch.matmul(x, lt_matrix) 
        # sum across partitions (N,C,V,L) 
        x = torch.sum(x, dim=(1)) 
        # match the dimension ordering of the input (N,C,V,L) -> (N,C,L,V) 
        x = x.permute(0,1,3,2) 
 
        # normalize the output of the st-gcn operation and activate 
        x = self.bn_relu(x) 
 
        # add the branches (main + residual), activate and dropout 
        return self.do(x + res) 



class OriginalStgcn(nn.Module):
    """Original classification spatial temporal graph convolutional networks.

    Data provision (batching, unfolding, etc.) is delegated to the caller. Model operates 
    on frame-by-frame basis and only requires an input buffer supplied to in the size of
    the requested receptive field.

    Args:
        in_channels (int): Number of channels in the input data
        num_class (int): Number of classes for the classification task
        graph_args (dict): The arguments for building the graph
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        **kwargs (optional): Other parameters for graph convolution units
    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num_class)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instance in a frame.
    """

    def __init__(self, **kwargs):
        super().__init__()

        self.conf = kwargs

        # load graph
        self.graph = Graph(strategy=kwargs['strategy'], **kwargs['graph'])
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = kwargs['graph']['num_node']
        temporal_kernel_size = kwargs['kernel'][0]
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        
        self.data_bn = nn.BatchNorm1d(kwargs['in_feat'] * A.size(1))
        # fcn for feature remapping of input to the network size
        self.fcn_in = nn.Conv2d(in_channels=kwargs['in_feat'], out_channels=kwargs['in_ch'][0][0], kernel_size=1)

        stack = [[st_gcn_original(
                    in_channels=kwargs['in_ch'][i][j],
                    out_channels=kwargs['out_ch'][i][j],
                    kernel_size=kernel_size,
                    partitions=A.size(0),
                    stride=kwargs['stride'][i][j],
                    residual=not not kwargs['residual'][i][j],
                    dropout=kwargs['dropout'][i][j])
                for j in range(layers_in_stage)] 
                for i, layers_in_stage in enumerate(kwargs['layers'])]
        self.st_gcn_networks = nn.ModuleList([module for sublist in stack for module in sublist])

        # initialize parameters for edge importance weighting
        if kwargs['importance']:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # fcn for prediction
        self.fcn = nn.Conv2d(
            kwargs['out_ch'][-1][-1],
            out_channels=kwargs['num_classes'],
            kernel_size=1)


    def forward(self, x):
        # data normalization
        N, C, T, V = x.size()
        # permutes must copy the tensor over as contiguous because .view() needs a contiguous tensor
        # this incures extra overhead
        x = x.permute(0, 3, 1, 2).contiguous()
        # (N,V,C,T)
        x = x.view(N, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, V, C, T)
        x = x.permute(0, 2, 3, 1)
        # (N,C,T,V)

        # remap the features to the network size
        x = self.fcn_in(x)

        # forward
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x = checkpoint(gcn, x, self.A * importance)

        # global pooling (across time L, and nodes V)
        x = F.avg_pool2d(x, x.size()[2:])

        # prediction
        x = self.fcn(x)

        return x


    def extract_feature(self, x):
        # data normalization
        N, C, T, V, M = x.size()
        # permutes must copy the tensor over as contiguous because .view() needs a contiguous tensor
        # this incures extra overhead
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        # (N,M,V,C,T)
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)
        # (N',C,T,V)

        # remap the features to the network size
        x = self.fcn_in(x)

        # forward
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # global pooling (across time L, and nodes V)
        x = F.avg_pool2d(x, x.size()[2:])

        feature = x.squeeze(-1)

        # prediction
        x = self.fcn(x)
        output = x.squeeze(-1)

        return output, feature


class st_gcn_original(nn.Module):
    """Applies a spatial temporal graph convolution over an input graph sequence.
    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (int, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        partitions,
        stride=1,
        dropout=0,
        residual=True):
        
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = (((kernel_size[0] - 1) // 2), 0)

        self.gcn = ConvTemporalGraphical(
            in_channels, 
            out_channels,
            kernel_size[1],
            partitions)

        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                stride=(stride, 1),
                padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True))

        if not residual:
            self.residual = lambda x: 0

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x

        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)),
                nn.BatchNorm2d(out_channels))

        self.relu = nn.ReLU(inplace=True)


    def forward(self, x, A):
        res = self.residual(x)
        # graph convolution
        x = self.gcn(x, A)
        # temporal accumulation (but using a learnable kernel)
        x = self.tcn(x)

        return self.relu(x + res)


class AdaptedStgcn(nn.Module):
    """Spatial temporal graph convolutional networks, adapted for segmentation.

    Args:
        in_channels (int): Number of channels in the input data
        num_class (int): Number of classes for the classification task
        graph_args (dict): The arguments for building the graph
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        **kwargs (optional): Other parameters for graph convolution units
    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num_class)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instance in a frame.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # load graph
        self.graph = Graph(strategy=kwargs['strategy'], **kwargs['graph'])
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = kwargs['graph']['num_node']
        temporal_kernel_size = kwargs['kernel'][0]
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        
        self.data_bn = nn.BatchNorm1d(kwargs['in_feat'] * A.size(1))
        # fcn for feature remapping of input to the network size
        self.fcn_in = nn.Conv2d(in_channels=kwargs['in_feat'], out_channels=kwargs['in_ch'][0][0], kernel_size=1)

        stack = [[st_gcn_adapted(
                    in_channels=kwargs['in_ch'][i][j],
                    out_channels=kwargs['out_ch'][i][j],
                    kernel_size=kernel_size,
                    partitions=A.size(0),
                    stride=kwargs['stride'][i][j],
                    residual=not not kwargs['residual'][i][j],
                    dropout=kwargs['dropout'][i][j])
                for j in range(layers_in_stage)] 
                for i, layers_in_stage in enumerate(kwargs['layers'])]
        self.st_gcn_networks = nn.ModuleList([module for sublist in stack for module in sublist])

        # initialize parameters for edge importance weighting
        if kwargs['importance']:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # fcn for prediction
        self.fcn = nn.Conv2d(
            kwargs['out_ch'][-1][-1], 
            out_channels=kwargs['num_classes'],
            kernel_size=1)


    def forward(self, x):
        # data normalization
        N, C, T, V = x.size()
        # permutes must copy the tensor over as contiguous because .view() needs a contiguous tensor
        # this incures extra overhead
        x = x.permute(0, 3, 1, 2).contiguous()
        # (N,V,C,T)
        x = x.view(N, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, V, C, T)
        x = x.permute(0, 2, 3, 1)
        # (N,C,T,V)

        # remap the features to the network size
        x = self.fcn_in(x)

        # forward
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x = checkpoint(gcn, x, self.A * importance)

        # global pooling
        x = F.avg_pool2d(x, (1, x.size()[-1]))

        # prediction
        x = self.fcn(x)
        x = x.squeeze(-1)

        return x


    def extract_feature(self, x):
        # data normalization
        N, C, T, V, M = x.size()
        # permutes must copy the tensor over as contiguous because .view() needs a contiguous tensor
        # this incures extra overhead
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        # (N,M,V,C,T)
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)
        # (N',C,T,V)

        # remap the features to the network size
        x = self.fcn_in(x)

        # forward
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x = checkpoint(gcn, x, self.A * importance)

        # global pooling
        x = F.avg_pool2d(x, (1, x.size()[-1]))

        feature = x.squeeze(-1)

        # prediction
        x = self.fcn(x)
        output = x.squeeze(-1)

        return output, feature


class st_gcn_adapted(nn.Module):
    """Applies a spatial temporal graph convolution over an input graph sequence.
    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (int, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        partitions,
        stride=1,
        dropout=0,
        residual=True):
        
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = (((kernel_size[0] - 1) // 2) * stride, 0)

        self.gcn = ConvTemporalGraphical(
            in_channels, 
            out_channels,
            kernel_size[1],
            partitions)

        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                dilation=(stride, 1),
                padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True))

        if not residual:
            self.residual = lambda x: 0

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x

        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1),
                nn.BatchNorm2d(out_channels))

        self.relu = nn.ReLU(inplace=True)


    def forward(self, x, A):
        res = self.residual(x)
        # graph convolution
        x = self.gcn(x, A)
        # temporal accumulation (but using a learnable kernel)
        x = self.tcn(x)

        return self.relu(x + res)
        

  assert((len(kwargs['in_ch'][i]) == layers_in_stage) and


In [3]:
# top-level custom CLI parser
parser = st_gcn_parser.Parser(
    prog='main',
    description='Script for human action segmentation processing using ST-GCN networks.',
    epilog='TODO: add the epilog')

subparsers= parser.add_subparsers(metavar='command')

# train command parser (must manually update usage after changes 
# to the argument list or provide a custom formatter)
parser_train = subparsers.add_parser(
    'train',
    usage="""%(prog)s [-h]
        \r\t[--config FILE]            
        \r\t[--model MODEL {realtime|buffer_realtime|batch|original}]
        \r\t[--strategy STRATEGY {uniform|distance|spatial}]
        \r\t[--in_feat IN_FEAT]
        \r\t[--stages STAGES]
        \r\t[--buffer BUFFER]
        \r\t[--kernel [KERNEL]]
        \r\t[--segment [SEGMENT]]
        \r\t[--importance]
        \r\t[--latency]
        \r\t[--receptive_field FIELD]
        \r\t[--layers [LAYERS]]
        \r\t[--in_ch [IN_CH,[...]]]
        \r\t[--out_ch [OUT_CH,[...]]]
        \r\t[--stride [STRIDE,[...]]]
        \r\t[--residual [RESIDUAL,[...]]]
        \r\t[--dropout [DROPOUT,[...]]]
        \r\t[--graph FILE]

        \r\t[--seed SEED]
        \r\t[--epochs EPOCHS]
        \r\t[--checkpoints [CHECKPOINTS]]
        \r\t[--learning_rate RATE]
        \r\t[--learning_rate_decay RATE_DECAY]
        \r\t[--batch_size BATCH]

        \r\t[--data DATA_DIR]
        \r\t[--dataset_type TYPE]
        \r\t[--actions FILE]
        \r\t[--out OUT_DIR]
        \r\t[--checkpoint CHECKPOINT]
        \r\t[--log O_FILE E_FILE]
        \r\t[--email EMAIL]
        \r\t[-v[vv]]""",
    help='train target ST-GCN network',
    epilog='TODO: add the epilog')

parser_train_model = parser_train.add_argument_group(
    'model',
    'arguments for configuring the ST-GCN model. '
    'If an argument is not provided, defaults to value inside config file. '
    'User can provide own config JSON file using --config argument, '
    'but it is the user\'s responsibility to provide all needed parameters')
parser_train_optim = parser_train.add_argument_group(
    'optimizer',
    'arguments for configuring training')
parser_train_io = parser_train.add_argument_group(
    'IO',
    'all miscallenous IO, log, file and path arguments')

# model arguments
parser_train_model.add_argument(
    '--config',
    type=str,
    default='config/kinetics/realtime_local.json',
    metavar='',
    help='path to the NN config file. Must be the last argument if combined '
        'with other CLI arguments. Provides default values for all arguments, except --log '
        '(default: config/kinetics/realtime_local.json)')
parser_train_model.add_argument(
    '--model',
    choices=['realtime','buffer_realtime','batch','original'],
    metavar='',
    help='type of NN model to use (default: realtime)')
parser_train_model.add_argument(
    '--strategy',
    choices=['uniform','distance','spatial'],
    metavar='',
    help='type of graph partitioning strategy to use (default: spatial)')
parser_train_model.add_argument(
    '--in_feat',
    type=int,
    metavar='',
    help='number of features/channels in data samples (default: 3)')
parser_train_model.add_argument(
    '--stages',
    type=int,
    metavar='',
    help='number of ST-GCN stages to stack (default: 1)')
parser_train_model.add_argument(
    '--buffer',
    type=int,
    metavar='',
    help='number of frames to buffer before batch processing. '
        'Applied only when --model=buffer_realtime (default: 1)')
parser_train_model.add_argument(
    '--kernel',
    type=int,
    nargs='+',
    metavar='',
    help='list of temporal kernel sizes (Gamma) per stage (default: [9])')
parser_train_model.add_argument(
    '--segment',
    type=int,
    metavar='',
    help='size of overlapping segments of frames to divide a trial into for '
        'parallelizing computation (creates a new batch dimension). '
        'Currently only supports datasets with different length trials. '
        'Applied only when --model != original and --dataset_type=dir (default: 100)')
parser_train_model.add_argument(
    '--importance',
    default=True,
    action='store_true',
    help='flag specifying whether ST-GCN layers have edge importance weighting '
        '(default: True)')
parser_train_model.add_argument(
    '--latency',
    default=False,
    action='store_true',
    help='flag specifying whether ST-GCN layers have half-buffer latency, '
        'or non-overlapping window when --model=original (default: False)')
parser_train_model.add_argument(
    '--receptive_field',
    type=int,
    metavar='',
    help='number of frames in a sliding window across raw inputs. '
        'Applied only when --model=original (default: 50)')
parser_train_model.add_argument(
    '--layers',
    type=int,
    nargs='+',
    metavar='',
    help='list of number of ST-GCN layers per stage (default: [9])')
parser_train_model.add_argument(
    '--in_ch',
    type=int,
    nargs='+',
    action='append',
    metavar='',
    help='list of number of input channels per ST-GCN layer per stage. '
        'For multi-stage, pass --in_ch parameter multiple times '
        '(default: [[64,64,64,64,128,128,128,256,256]])')
parser_train_model.add_argument(
    '--out_ch',
    type=int, 
    nargs='+',
    action='append',
    metavar='',
    help='list of number of output channels per ST-GCN layer per stage. '
        'For multi-stage, pass --out_ch parameter multiple times '
        '(default: [[64,64,64,128,128,128,256,256,256]])')
parser_train_model.add_argument(
    '--stride',
    type=int, 
    nargs='+',
    action='append',
    metavar='',
    help='list of size of stride in temporal accumulation per ST-GCN layer per stage. '
        'For multi-stage, pass --stride parameter multiple times '
        '(default: [[1,1,1,2,1,1,2,1,1]])')
parser_train_model.add_argument(
    '--residual',
    type=int, 
    nargs='+',
    action='append',
    metavar='',
    help='list of binary flags specifying residual connection per ST-GCN layer per stage. '
        'For multi-stage, pass --residual parameter multiple times '
        '(default: [[0,1,1,1,1,1,1,1,1]])')
parser_train_model.add_argument(
    '--dropout',
    type=float,
    nargs='+',
    action='append',
    metavar='',
    help='list of dropout values per ST-GCN layer per stage. '
        'For multi-stage, pass --dropout parameter multiple times '
        '(default: [[0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5]])')
parser_train_model.add_argument(
    '--graph',
    type=str,
    metavar='',
    help='path to the skeleton graph specification file '
        '(default: data/skeletons/openpose.json)')
# optimizer arguments
parser_train_optim.add_argument(
    '--seed',
    type=int,
    metavar='',
    help='seed for the random number generator (default: 1538574472)')
parser_train_optim.add_argument(
    '--epochs',
    type=int,
    metavar='',
    help='number of epochs to train the NN over (default: 100)')
parser_train_optim.add_argument(
    '--checkpoints',
    type=int,
    nargs='+',
    metavar='',
    help='list of epochs to checkpoint the model at '
        '(default: [19, 39, 59, 79, 99])')
parser_train_optim.add_argument(
    '--learning_rate',
    type=float,
    metavar='',
    help='learning rate of the optimizer (default: 0.01)')
parser_train_optim.add_argument(
    '--learning_rate_decay',
    type=float,
    metavar='',
    help='learning rate decay factor of the optimizer (default: 0.1)')
parser_train_optim.add_argument(
    '--batch_size',
    type=int,
    metavar='',
    help='number of captures to process in a minibatch (default: 16)')
# IO arguments
parser_train_io.add_argument(
    '--data',
    metavar='',
    help='path to the dataset directory (default: data/kinetics)')
parser_train_io.add_argument(
    '--dataset_type',
    metavar='',
    help='type of the dataset (default: file)')
parser_train_io.add_argument(
    '--actions',
    metavar='',
    help='path to the action classes file (default: data/kinetics/actions.txt)')
parser_train_io.add_argument(
    '--out',
    metavar='',
    help='path to the output directory (default: pretrained_models/kinetics)')
parser_train_io.add_argument(
    '--checkpoint',
    type=str,
    metavar='',
    default=None,
    help='path to the checkpoint to restore states from (default: None)')
parser_train_io.add_argument(
    '--log',
    nargs=2,
    type=argparse.FileType('w'),
    # const=[t1+t2+'.txt' for t1, t2 in zip(['log.o.','log.e.'],2*[str(time.time())])],
    default=[sys.stdout, sys.stderr],
    metavar='',
    help='files to log the script to. Only argument without default option in --config '
        '(default: stdout, stderr)')
parser_train_io.add_argument(
    '--email',
    type=str,
    metavar='',
    default=None,
    help='email address to send update notifications to (default: None)')
parser_train_io.add_argument(
    '-v', '--verbose', dest='verbose',
    action='count', 
    default=0,
    help='level of log detail (default: 0)')

# test command parser
parser_test = subparsers.add_parser(
    'test',
    usage="""%(prog)s\n\t[-h]
        \r\t[--config FILE]            
        \r\t[--model MODEL {realtime|buffer_realtime|batch|original}]
        \r\t[--strategy STRATEGY {uniform|distance|spatial}]
        \r\t[--in_feat IN_FEAT]
        \r\t[--stages STAGES]
        \r\t[--buffer BUFFER]
        \r\t[--kernel [KERNEL]]
        \r\t[--importance]
        \r\t[--latency]
        \r\t[--layers [LAYERS]]
        \r\t[--in_ch [IN_CH,[...]]]
        \r\t[--out_ch [OUT_CH,[...]]]
        \r\t[--stride [STRIDE,[...]]]
        \r\t[--residual [RESIDUAL,[...]]]
        \r\t[--dropout [DROPOUT,[...]]]
        \r\t[--graph FILE]

        \r\t[--data DATA_DIR]
        \r\t[--dataset_type TYPE]
        \r\t[--actions FILE]
        \r\t[--out OUT_DIR]
        \r\t[--checkpoint CHECKPOINT]
        \r\t[--log O_FILE E_FILE]
        \r\t[--email EMAIL]
        \r\t[-v[vv]]""",
    help='test target ST-GCN network',
    epilog='TODO: add the epilog')

parser_test_model = parser_test.add_argument_group(
    'model',
    'arguments for configuring the ST-GCN model. '
    'If an argument is not provided, defaults to value inside config file. '
    'User can provide own config JSON file using --config argument, '
    'but it is the user\'s responsibility to provide all needed parameters')
parser_test_io = parser_test.add_argument_group(
    'IO',
    'all miscallenous IO, log, file and path arguments')

# model arguments
parser_test_model.add_argument(
    '--config',
    type=str,
    default='config/kinetics/realtime_local.json',
    metavar='',
    help='path to the NN config file. Must be the last argument if combined '
        'with other CLI arguments. Provides default values for all arguments, except --log '
        '(default: config/kinetics/realtime_local.json)')
parser_test_model.add_argument(
    '--model',
    choices=['realtime','buffer_realtime','batch','original'],
    metavar='',
    help='type of NN model to use (default: realtime)')
parser_test_model.add_argument(
    '--strategy',
    choices=['uniform','distance','spatial'],
    metavar='',
    help='type of graph partitioning strategy to use (default: spatial)')
parser_test_model.add_argument(
    '--in_feat',
    type=int,
    metavar='',
    help='number of features/channels in data samples (default: 3)')
parser_test_model.add_argument(
    '--stages',
    type=int,
    metavar='',
    help='number of ST-GCN stages to stack (default: 1)')
parser_test_model.add_argument(
    '--buffer',
    type=int,
    metavar='',
    help='number of frames to buffer before batch processing. '
        'Applied only when --model=buffer_realtime (default: 1)')
parser_test_model.add_argument(
    '--kernel',
    type=int,
    nargs='+',
    metavar='',
    help='list of temporal kernel sizes (Gamma) per stage (default: [9])')
parser_test_model.add_argument(
    '--importance',
    default=True,
    action='store_true',
    help='flag specifying whether ST-GCN layers have edge importance weighting '
        '(default: True)')
parser_test_model.add_argument(
    '--latency',
    default=False,
    action='store_true',
    help='flag specifying whether ST-GCN layers have half-buffer latency '
        '(default: False)')
parser_test_model.add_argument(
    '--layers',
    type=int,
    nargs='+',
    metavar='',
    help='list of number of ST-GCN layers per stage (default: [9])')
parser_test_model.add_argument(
    '--in_ch',
    type=int,
    nargs='+',
    action='append',
    metavar='',
    help='list of number of input channels per ST-GCN layer per stage. '
        'For multi-stage, pass --in_ch parameter multiple times '
        '(default: [[64,64,64,64,128,128,128,256,256]])')
parser_test_model.add_argument(
    '--out_ch',
    type=int, 
    nargs='+',
    action='append',
    metavar='',
    help='list of number of output channels per ST-GCN layer per stage. '
        'For multi-stage, pass --out_ch parameter multiple times '
        '(default: [[64,64,64,128,128,128,256,256,256]])')
parser_test_model.add_argument(
    '--stride',
    type=int, 
    nargs='+',
    action='append',
    metavar='',
    help='list of size of stride in temporal accumulation per ST-GCN layer per stage. '
        'For multi-stage, pass --stride parameter multiple times '
        '(default: [[1,1,1,2,1,1,2,1,1]])')
parser_test_model.add_argument(
    '--residual',
    type=int, 
    nargs='+',
    action='append',
    metavar='',
    help='list of binary flags specifying residual connection per ST-GCN layer per stage. '
        'For multi-stage, pass --residual parameter multiple times '
        '(default: [[0,1,1,1,1,1,1,1,1]])')
parser_test_model.add_argument(
    '--dropout',
    type=float,
    nargs='+',
    action='append',
    metavar='',
    help='list of dropout values per ST-GCN layer per stage. '
        'For multi-stage, pass --dropout parameter multiple times '
        '(default: [[0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5]])')
parser_test_model.add_argument(
    '--graph',
    type=str,
    metavar='',
    help='path to the skeleton graph specification file '
        '(default: data/skeletons/openpose.json)')
# IO arguments
parser_test_io.add_argument(
    '--data',
    metavar='',
    help='path to the dataset directory (default: data/kinetics)')
parser_test_io.add_argument(
    '--dataset_type',
    metavar='',
    help='type of the dataset (default: file)')
parser_test_io.add_argument(
    '--actions',
    metavar='',
    help='path to the action classes file (default: data/kinetics/actions.txt)')
parser_test_io.add_argument(
    '--out',
    metavar='',
    help='path to the output directory (default: pretrained_models/kinetics)')
parser_test_io.add_argument(
    '--checkpoint',
    type=str,
    metavar='',
    default=None,
    help='path to the checkpoint to restore states from (default: None)')
parser_test_io.add_argument(
    '--log',
    nargs=2,
    type=argparse.FileType('w'),
    # const=[t1+t2+'.txt' for t1, t2 in zip(['log.o.','log.e.'],2*[str(time.time())])],
    default=[sys.stdout, sys.stderr],
    metavar='',
    help='files to log the script to. Only argument without default option in --config '
        '(default: stdout, stderr)')
parser_test_io.add_argument(
    '--email',
    type=str,
    metavar='',
    default=None,
    help='email address to send update notifications to (default: None)')
parser_test_io.add_argument(
    '-v', '--verbose', dest='verbose',
    action='count', 
    default=0,
    help='level of log detail (default: 0)')

##################################################################
# benchmark command parser
# TODO: setup all the needed CLI arguments
parser_benchmark = subparsers.add_parser(
    'benchmark',
    usage="""%(prog)s\n\t[-h]

        """,
    help='benchmark target ST-GCN network against baseline(s)',
    epilog='TODO: add the epilog')
##################################################################

parser_train.set_defaults(func=train)
parser_test.set_defaults(func=test)
parser_benchmark.set_defaults(func=benchmark)

In [4]:
# parse the arguments
args = parser.parse_args("train --epochs 20 --batch_size 32 --config config/pku-mmd/adapted_vsc.json".split())
args.func(args)

Training started


KeyboardInterrupt: 