### Imports

In [None]:
import numpy as np
import pandas as pd
import scipy.io as io
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass, field, asdict
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.nn import init
from torch.optim import lr_scheduler
from pytorch_lightning import Trainer
from torch.autograd import Variable, Function
from pytorch_lightning.callbacks import EarlyStopping
from torch.nn.init import kaiming_normal_, orthogonal_
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn import metrics
import json
import torch
import math
import time
import warnings
import os
import h5py
import tables
warnings.filterwarnings("ignore")

In [None]:
@dataclass
class paramclass:
    """Hyper-parameters and run-time settings for the CapsNet experiment.

    Every *annotated* attribute is a dataclass field: it can be passed to the
    constructor and is included in ``dataclasses.asdict`` (which ``CapsNet``
    uses to build its ``hparams``).  ``lr``, ``l`` and ``max_epochs`` used to
    be un-annotated, which silently made them plain class attributes: they
    were missing from ``asdict`` and ``paramclass(lr=...)`` raised
    ``TypeError``.  They are now real fields, appended *after* the original
    ones so positional construction keeps its meaning.
    """

    # --- original fields (positional order unchanged) ---
    batch_size: int = 64          # batch size of the training DataLoader
    routing_iterations: int = 5   # iterations run by AgreementRouting
    n_classes: int = 8            # number of output (activity) capsules
    num_workers: int = 0          # worker processes of the training DataLoader
    resume: str = ''              # checkpoint path read when first_training is False

    # --- formerly un-annotated hyper-parameters, now proper fields ---
    lr: float = 1e-03             # Adam learning rate
    l: float = 0.01               # mixing coefficient handed to AgreementRouting
    max_epochs: int = 300         # forwarded to the Trainer

    # --- run-time switches, intentionally left as plain class attributes ---
    # They stay out of ``asdict`` so that they are not written into the
    # hyper-parameter dict; override them on the instance when needed.
    device = torch.device('cuda:0')
    first_training = True

# Global experiment configuration shared by the dataset and the model.
params = paramclass()

# One HDF5 file per subject ("proband").  The subjects listed here form the
# training set; proband10 is held out for validation and proband11 for test.
params.train_paths = [
    './data/realworld2016_dataset/proband{0}/processed/proband{0}.hdf'.format(subject)
    for subject in (3, 5, 6, 8, 9, 12, 13, 15)
]

params.val_path = './data/realworld2016_dataset/proband10/processed/proband10.hdf'
params.test_path = './data/realworld2016_dataset/proband11/processed/proband11.hdf'

In [None]:
class dataset(torch.utils.data.Dataset):
    """In-memory dataset built from the HDF5 files named in ``params``.

    Each file provides ``/X`` (sensor windows) and ``/y`` (integer labels).
    ``/X`` is transposed with ``(0, 2, 1)`` and reshaped to
    ``(N, 42, 250, 1)`` (presumably it is stored as ``(N, 250, 42)`` -- TODO
    confirm), standardised with statistics of the *training* data, cropped to
    the first 128 samples and exposed as ``(N, 42, 128)``.  Labels are
    one-hot encoded.

    Args:
        params: object with ``train_paths`` (list of paths), ``val_path`` and
            ``test_path``; ``n_classes`` is used for the one-hot width when
            present (falls back to 8).
        type: ``'train'``, ``'test'`` or anything else for validation.
        device: stored on the instance; not read by this class.
    """

    _N_CHANNELS = 42      # rows of every window (7 groups of 6, see preprocess)
    _WINDOW_LEN = 250     # samples per stored window
    _CROP_LEN = 128       # samples kept per window
    # Only the first 18 channels are standardised.
    # NOTE(review): with 42 channels this looks like a leftover from a
    # 3-sensor configuration -- confirm before changing, it alters the inputs.
    _N_STANDARDISED = 18
    _IMU_WIDTH = 6        # channels per group returned by preprocess

    def __init__(self, params, type='train', device=torch.device('cuda:0')):
        self.params = params
        self.device = device

        # Training data is always required: it is either the requested split
        # or the source of the normalisation statistics for the other splits.
        train_parts = [self._load_split(path) for path in self.params.train_paths]
        X_train = np.concatenate([part[0] for part in train_parts], axis=0)
        y_train = np.concatenate([part[1] for part in train_parts], axis=0)

        # Load only the split that was asked for (the previous code always
        # read train, validation and test files).
        if type == 'train':
            X, y = X_train, y_train
        elif type == 'test':
            X, y = self._load_split(self.params.test_path)
        else:
            X, y = self._load_split(self.params.val_path)

        self._standardise(X, X_train)

        # Crop in time and drop the trailing singleton axis -> (N, 42, 128).
        self.X = X[:, :, :self._CROP_LEN, 0]
        self.y = y
        self.datalen = self.X.shape[0]

    def _load_split(self, path):
        """Read one HDF5 file and return ``(X, y_one_hot)``."""
        with h5py.File(path, 'r') as handle:
            raw = handle['/X'][:]       # read once (was read twice per file)
            labels = handle['/y'][:]
        X = raw.transpose(0, 2, 1).reshape(
            raw.shape[0], self._N_CHANNELS, self._WINDOW_LEN, 1)
        n_classes = getattr(self.params, 'n_classes', 8)
        y = np.zeros((labels.shape[0], n_classes))
        y[np.arange(labels.shape[0]), labels.astype('int').squeeze()] = 1
        return X, y

    def _standardise(self, X, X_train):
        """Standardise ``X`` in place, channel by channel, with train statistics.

        ``X`` may be ``X_train`` itself; each channel's statistics are taken
        before that channel is overwritten.
        """
        for i in range(self._N_STANDARDISED):
            train_mean = np.mean(X_train[:, i, :, :])
            train_std = np.std(X_train[:, i, :, :])
            if train_std == 0:
                # A constant channel would otherwise become NaN (0/0).
                train_std = 1.0
            X[:, i, :, :] = (X[:, i, :, :] - train_mean) / train_std

    def preprocess(self, data):
        """Split a ``(B, 42, W)`` tensor into seven ``(B, 1, 6, W)`` float tensors.

        Consecutive groups of six rows are returned in order (rows 0-5, 6-11,
        ..., 36-41).
        """
        B, H, W = data.shape
        data = data.view(B, 1, H, W)
        return tuple(
            data[:, :, start:start + self._IMU_WIDTH, :].float()
            for start in range(0, self._N_CHANNELS, self._IMU_WIDTH)
        )

    def __len__(self):
        return self.datalen

    def __getitem__(self, idx):
        """Return ``(window, one_hot_label)`` for ``idx`` (int or tensor index)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.X[idx], self.y[idx]

def conv2d(batchNorm, in_channels, out_channels, kernel_size, stride=1):
    """Build a convolution block: Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2) -> Dropout2d(0.35).

    When ``batchNorm`` is truthy a ``BatchNorm2d`` follows the convolution and
    the convolution is created without a bias term; otherwise the convolution
    keeps its bias and no normalisation layer is inserted.
    """
    use_bn = bool(batchNorm)
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                  stride=stride, bias=not use_bn)
    ]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    layers.append(nn.Dropout2d(p=0.35))
    return nn.Sequential(*layers)

def squash(x, eps=1e-8):
    """Capsule "squash" non-linearity along dimension 2.

    Every vector ``v = x[b, c, :]`` is rescaled to
    ``v * |v|^2 / (1 + |v|^2) / |v|``, so its direction is kept and its length
    is mapped into ``[0, 1)``.

    Args:
        x: tensor whose dimension 2 holds the capsule vectors.
        eps: small constant added under the square root.  Without it an
            all-zero capsule yields ``0 / 0 = NaN`` in the forward pass and a
            non-finite gradient of ``sqrt`` in the backward pass.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # keepdim=True keeps the reduced axis so the scale broadcasts against x
    # directly, without a shape-specific view().
    lengths2 = x.pow(2).sum(dim=2, keepdim=True)
    lengths = (lengths2 + eps).sqrt()
    return x * (lengths2 / (1 + lengths2) / lengths)

class MarginLoss(pl.LightningModule):
    """Margin loss on capsule lengths.

    For every class the loss is
    ``t * relu(m_pos - len)^2 + lambda_ * (1 - t) * relu(len - m_neg)^2``
    where ``t`` is the (one-hot) target and ``len`` the capsule length.
    """

    def __init__(self, m_pos, m_neg, lambda_):
        super().__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, lengths, targets, size_average=True):
        """Return the mean (``size_average=True``) or the sum of the element-wise loss."""
        present = targets.float()
        # Penalise capsules of the true class that are shorter than m_pos ...
        positive_term = present * F.relu(self.m_pos - lengths).pow(2)
        # ... and, down-weighted by lambda_, other capsules longer than m_neg.
        negative_term = self.lambda_ * (1. - present) * F.relu(lengths - self.m_neg).pow(2)
        losses = positive_term + negative_term
        if size_average:
            return losses.mean()
        return losses.sum()

class Large_Encoder(pl.LightningModule):
    """Three stacked ``conv2d`` blocks (1 -> 48 -> 64 -> 96 channels).

    Args:
        batchNorm: forwarded to ``conv2d``; selects the variant with
            ``BatchNorm2d`` and bias-free convolutions.
    """

    def __init__(self, batchNorm):
        super().__init__()

        self.batchNorm = batchNorm

        self.conv1 = conv2d(self.batchNorm, 1, 48, kernel_size=(1, 10), stride=(1, 2))
        self.conv2 = conv2d(self.batchNorm, 48, 64, kernel_size=(3, 10), stride=(3, 2))
        self.conv3 = conv2d(self.batchNorm, 64, 96, kernel_size=(2, 15), stride=1)

        # Re-initialise: Kaiming-normal weights and zero biases for the
        # weighted layers, identity affine transform for batch-norm layers.
        weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
        for module in self.modules():
            if isinstance(module, weighted_layers):
                kaiming_normal_(module.weight.data)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        """Apply conv1, conv2 and conv3 in sequence."""
        for block in (self.conv1, self.conv2, self.conv3):
            x = block(x)
        return x
    
    
class AgreementRouting(pl.LightningModule):
    """Routing-by-agreement with learnable prior logits and damped updates.

    Args:
        input_caps: number of input capsules.
        output_caps: number of output capsules.
        n_iterations: number of routing refinements after the initial pass.
        l: mixing coefficient; each iteration sets
            ``b <- (1 - l) * b + l * agreement``.
    """

    def __init__(self, input_caps, output_caps, n_iterations, l):
        super(AgreementRouting, self).__init__()
        self.n_iterations = n_iterations
        # Learnable routing logits shared by all samples.
        self.b = nn.Parameter(torch.zeros((input_caps, output_caps)))
        self.l = l

    def forward(self, u_predict):
        """Route predictions to output capsules.

        Args:
            u_predict: tensor of shape
                ``(batch, input_caps, output_caps, output_dim)``.

        Returns:
            Squashed output capsules, ``(batch, output_caps, output_dim)``.
        """
        batch_size, input_caps, output_caps, _ = u_predict.size()

        # The softmax axis is now explicit.  The implicit-dim form is
        # deprecated; for these 2-D inputs it resolved to dim=1 (the
        # output-capsule axis), so the result is unchanged.
        c = F.softmax(self.b, dim=1)
        s = (c.unsqueeze(2) * u_predict).sum(dim=1)
        v = squash(s)

        if self.n_iterations > 0:
            b_batch = self.b.expand((batch_size, input_caps, output_caps))
            for _ in range(self.n_iterations):
                v = v.unsqueeze(1)
                # Agreement = dot product between predictions and current output.
                b_batch = (1 - self.l) * b_batch + self.l * (u_predict * v).sum(-1)

                c = F.softmax(b_batch, dim=-1).unsqueeze(-1)
                s = (c * u_predict).sum(dim=1)
                v = squash(s)
        return v

    
class CapsLayer(pl.LightningModule):
    """Capsule layer: per-input-capsule linear prediction followed by routing.

    Args:
        input_caps: number of input capsules.
        input_dim: dimensionality of each input capsule.
        output_caps: number of output capsules.
        output_dim: dimensionality of each output capsule.
        routing_module: callable that maps the prediction tensor
            ``(batch, input_caps, output_caps, output_dim)`` to the output.
    """

    def __init__(self, input_caps, input_dim, output_caps, output_dim, routing_module):
        super().__init__()
        self.input_dim = input_dim
        self.input_caps = input_caps
        self.output_dim = output_dim
        self.output_caps = output_caps
        # One (input_dim x output_caps*output_dim) matrix per input capsule;
        # the values are set by reset_parameters().
        self.weights = nn.Parameter(
            torch.empty(input_caps, input_dim, output_caps * output_dim))
        self.routing_module = routing_module
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the weights uniformly from ``[-1/sqrt(input_caps), 1/sqrt(input_caps)]``."""
        bound = 1. / math.sqrt(self.input_caps)
        self.weights.data.uniform_(-bound, bound)

    def forward(self, caps_output):
        """Compute predictions for every (input, output) capsule pair and route them."""
        # Insert a singleton row axis so matmul broadcasts over the batch.
        predictions = caps_output.unsqueeze(2).matmul(self.weights)
        batch = predictions.size(0)
        u_predict = predictions.view(batch, self.input_caps, self.output_caps, self.output_dim)
        return self.routing_module(u_predict)

    
class PrimaryCapsLayer(pl.LightningModule):
    """Turn seven feature maps into one squashed set of primary capsules.

    Args:
        split: if true, the channel axis of every map is split into
            ``output_caps`` capsules of ``output_dim`` values per location;
            otherwise every spatial location yields one capsule whose vector
            is the full channel axis.
        output_caps: capsules per location (only used when ``split``).
        output_dim: capsule dimensionality (only used when ``split``).
    """

    def __init__(self, split=False, output_caps=None, output_dim=None):
        super().__init__()
        self.split = split
        self.output_caps = output_caps
        self.output_dim = output_dim

    def forward(self, input):
        """Concatenate the capsules of all maps along dimension 1 and squash them.

        ``input`` must contain exactly seven ``(B, C, H, W)`` tensors; the
        shape of the first one is used for all of them.
        """
        # The 7-way unpacking enforces the expected number of feature maps.
        first, _, _, _, _, _, _ = input
        B, C, H, W = first.shape

        if self.split:
            def to_capsules(feature_map):
                caps = feature_map.view(B, self.output_caps, self.output_dim, H, W)
                caps = caps.permute(0, 1, 3, 4, 2).contiguous()
                return caps.view(caps.size(0), -1, caps.size(4))
        else:
            def to_capsules(feature_map):
                caps = feature_map.permute(0, 2, 3, 1).contiguous()
                return caps.view(B, H * W, C)

        stacked = torch.cat([to_capsules(feature_map) for feature_map in input], dim=1)
        return squash(stacked)

class CapsNet(pl.LightningModule):
    """Capsule network over seven 6-channel sensor groups.

    Every group is encoded by the same ``Large_Encoder``; the resulting maps
    become primary capsules (``PrimaryCapsLayer``) which are routed to
    ``n_classes`` 16-dimensional activity capsules.  The length of an
    activity capsule is used as the class score.

    Args:
        hparams: a ``paramclass`` instance, or a dict of ``paramclass``
            constructor arguments.
    """

    def __init__(self, hparams):
        super(CapsNet, self).__init__()
        # isinstance (rather than an exact type comparison) also accepts dict
        # subclasses.
        if isinstance(hparams, dict):
            self.hparams = hparams
            self.params = paramclass(**hparams)
        else:
            self.hparams = asdict(hparams)
            self.params = hparams

        self.pp = Large_Encoder(False)
        self.primaryCaps = PrimaryCapsLayer()
        self.num_primaryCaps = 1*12*7
        routing_module = AgreementRouting(self.num_primaryCaps, self.params.n_classes, self.params.routing_iterations, self.params.l)
        self.activityCaps = CapsLayer(self.num_primaryCaps, 96, self.params.n_classes, 16, routing_module)

        self.loss_fn = MarginLoss(0.95, 0.05, 0.5)

        # Resume: ``load_from_checkpoint`` is a classmethod that builds and
        # returns a *new* model, so calling it here and dropping the result
        # never restored anything.  Load the saved weights into this instance
        # instead, once all sub-modules exist.
        if not self.params.first_training:
            checkpoint = torch.load(self.params.resume,
                                    map_location=lambda storage, loc: storage)
            self.load_state_dict(checkpoint['state_dict'])

    def prepare_data(self):
        """Build the training and validation datasets."""
        self.traindataArr = dataset(self.params, type='train')
        self.valdataArr = dataset(self.params, type='val')

    def train_dataloader(self):
        return DataLoader(self.traindataArr, batch_size=self.params.batch_size, shuffle=True, num_workers=self.params.num_workers, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.valdataArr, batch_size=1, shuffle=False, num_workers=0)

    def configure_optimizers(self):
        """Adam with an exponentially decaying learning rate (gamma = 0.975)."""
        optimizer = optim.Adam(self.parameters(), lr=self.params.lr)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.975, last_epoch=-1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        data, target = batch
        imus = self.traindataArr.preprocess(data)
        _, probs = self(list(imus))
        loss = self.loss_fn(probs, target)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        data, target = batch
        # Use the validation dataset's own preprocess (previously borrowed
        # from the training dataset; the operation itself is the same).
        imus = self.valdataArr.preprocess(data)
        _, probs = self(list(imus))
        loss = self.loss_fn(probs, target)
        return {'val_loss': loss.detach(), 'Target': target, 'Predictions': probs}

    def validation_epoch_end(self, outputs):
        """Aggregate the validation steps into the loss and classification metrics."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # Concatenate along the batch axis so any validation batch size works
        # (the previous stack + [:, 0, :] only handled batch_size == 1).
        targets = torch.cat([x['Target'] for x in outputs], dim=0).cpu().numpy()
        predictions = torch.cat([x['Predictions'] for x in outputs], dim=0).cpu().numpy()

        # argmax yields exactly one class per sample; comparing against the
        # row maximum returned several indices for tied scores and broke the
        # alignment with ``targets``.
        winners = predictions.argmax(axis=1)
        hotpredictions = np.zeros((winners.size, predictions.shape[1]))
        hotpredictions[np.arange(winners.size), winners] = 1

        f1_mac = metrics.f1_score(targets, hotpredictions, average='macro')
        f1_mic = metrics.f1_score(targets, hotpredictions, average='micro')
        acc = metrics.accuracy_score(targets, hotpredictions)
        precision_avg, recall_avg, f_score_avg, _ = metrics.precision_recall_fscore_support(targets, hotpredictions, average='weighted')
        tensorboard_logs = {'avg_val_loss': avg_loss,
                            'F1-Micro': f1_mic,
                            'F1-Macro': f1_mac,
                            'Accuracy': acc,
                            'Precision_Avg': precision_avg,
                            'Recall_Avg': recall_avg,
                            'F_score_Avg': f_score_avg
                            }
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def forward(self, input):
        """Return ``(activity_capsules, capsule_lengths)`` for a list of seven tensors."""
        capsules = [self.pp(imu) for imu in input]
        x = self.primaryCaps(capsules)
        x = self.activityCaps(x)
        probs = x.pow(2).sum(dim=2).sqrt()
        return x, probs
    
class Non_val_epoch_saves(pl.Callback):
    """Callback that writes a checkpoint at the end of every epoch.

    Args:
        iteration: string prefix of the checkpoint file names.
        filepath: directory the checkpoints are written to.
    """

    def __init__(self, iteration, filepath):
        self.iteration = iteration
        self.filepath = filepath

    def on_epoch_end(self, trainer, pl_module):
        """Save ``<iteration>_Epoch=<current_epoch>.ckpt`` via the trainer's checkpoint callback."""
        epoch_tag = str(trainer.current_epoch)
        self.name = ''.join((self.iteration, '_Epoch=', epoch_tag, '.ckpt'))
        target = os.path.join(self.filepath, self.name)
        trainer.checkpoint_callback._save_model(filepath=target)

In [None]:
# Build the model, attach the per-epoch checkpoint callback and start training.
iteration = 'ARCNet-Large-RealWorld'
callback_dir = './checkpoints/'
model = CapsNet(params)

epoch_saver = Non_val_epoch_saves(iteration=iteration, filepath=callback_dir)

# NOTE(review): fast_dev_run=True is kept as in the original run; it looks
# like a debugging leftover next to max_epochs -- confirm before a real run.
trainer = Trainer(
    gpus=1,
    default_save_path='./checkpoints/',
    track_grad_norm=2,
    max_epochs=params.max_epochs,
    progress_bar_refresh_rate=50,
    weights_summary='top',
    fast_dev_run=True,
    callbacks=[epoch_saver],
)
trainer.fit(model)