In [6]:
import numpy as np
from tofu.modules import ToFULayer
import torch
from tofu.config import data_dir, model_dir
from tofu.utils import motion_capture_pd_dim, motion_capture_pad
import os
from scipy.io import loadmat

In [7]:
# load data
# NOTE(review): os.chdir changes the working directory for every later cell;
# reading via os.path.join(data_dir, ...) would avoid that global side effect.
os.chdir(data_dir)
data = loadmat('PD.mat')
labels = loadmat('Label_mocap.mat')

# Flatten the 'Label' array to 1-D and subtract 1 from every entry
# (presumably 1-based MATLAB class ids -> 0-based indices; confirm).
labels = labels['Label'].flatten() - 1

# Keep element [1] of each entry of the 'PD' array
# (layout defined by PD.mat -- TODO confirm what element [1] holds).
dgms = [entry[1] for entry in data['PD']]

In [8]:
# define model

class MotionClassifier(torch.nn.Module):
    """Fully connected classifier that maps a feature vector to class logits.

    Layout: three full-width layers (``in_dim -> lin_dim -> lin_dim -> lin_dim``),
    then two half-width layers (``lin_dim -> int(lin_dim / 2) -> int(lin_dim / 2)``),
    each followed by SiLU, and a final linear head producing ``n_classes``
    unnormalised scores. No softmax is applied; the output is raw logits.

    Args:
        in_dim: size of the last dimension of the input.
        lin_dim: hidden width of the full-width stage.
        n_classes: number of output logits.
    """

    def __init__(self, in_dim, lin_dim, n_classes):
        super().__init__()
        half_dim = int(lin_dim / 2)

        # Full-width stage.
        self.lin1 = torch.nn.Linear(in_dim, lin_dim)
        self.lin2 = torch.nn.Linear(lin_dim, lin_dim)
        self.lin3 = torch.nn.Linear(lin_dim, lin_dim)

        # Half-width stage.
        self.lin4 = torch.nn.Linear(lin_dim, half_dim)
        self.lin5 = torch.nn.Linear(half_dim, half_dim)

        # Output head: one raw logit per class.
        self.logits = torch.nn.Linear(half_dim, n_classes)

    def forward(self, h_batch):
        """Return logits of shape ``(..., n_classes)`` for input of shape ``(..., in_dim)``."""
        h = h_batch
        for layer in (self.lin1, self.lin2, self.lin3, self.lin4, self.lin5):
            h = torch.nn.functional.silu(layer(h))
        return self.logits(h)

In [9]:
# build model
# Seed torch so the ToFU / Linear weight initialisation below is reproducible
# across Restart & Run All.
INIT_SEED = 0
torch.manual_seed(INIT_SEED)

dgm_dims = 57          # one ToFULayer is built per index d in range(dgm_dims)
n_dgms = 4
n_feats = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
birth_lims = [10, 20]  # passed straight to ToFULayer; semantics defined there
birth_death = True     # passed straight to ToFULayer; semantics defined there
lin_dim = 256          # hidden width of MotionClassifier
n_classes = 5          # number of logits produced by MotionClassifier

# NOTE(review): in_dim assumes each ToFULayer contributes n_dgms features, i.e. it
# does not scale with n_feats (consistent while n_feats == 1). Confirm against
# ToFULayer's output size before changing n_feats.
in_dim = int(n_dgms * dgm_dims)

ToFU_dim = torch.nn.ModuleList([
    ToFULayer(n_dgms, n_feats, device, birth_lims, birth_death).to(device).float()
    for _ in range(dgm_dims)
])
classifier = MotionClassifier(in_dim, lin_dim, n_classes).to(device).float()

# Separate optimisers: only the classifier gets weight decay.
opt_tofu = torch.optim.Adam(ToFU_dim.parameters(), lr=1e-2)
opt_classifier = torch.optim.Adam(classifier.parameters(), lr=1e-2, weight_decay=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [10]:
# train
n_obs = len(dgms)
train_perc = 0.8
epochs, batch_size = 100, 128

# Seeded RNG: the split (and epoch shuffles) are identical on every run. With the
# unseeded global RNG, re-running only this cell re-drew the split while the model
# from the previous cell kept its trained weights, so samples it had already
# trained on could end up in the new validation set.
SPLIT_SEED = 0
rng = np.random.default_rng(SPLIT_SEED)

n_train = int(train_perc * n_obs)
assert 0 < n_train < n_obs, "train/val split left one side empty"

perm = rng.permutation(n_obs)
shuffled_data, shuffled_labels = [dgms[i] for i in perm], labels[perm]
X_train, y_train = shuffled_data[:n_train], shuffled_labels[:n_train]
X_val, y_val = shuffled_data[n_train:], shuffled_labels[n_train:]


def _tofu_inputs(X):
    """Preprocess diagrams X into one float tensor on `device` per index d in range(dgm_dims)."""
    return [torch.tensor(motion_capture_pad(motion_capture_pd_dim(X, d))).to(device).float()
            for d in range(dgm_dims)]


def _vectorize(inputs):
    """Apply ToFU_dim[d] to inputs[d] for every d; returns the list of outputs."""
    return [ToFU_dim[d](x_d) for d, x_d in enumerate(inputs)]


# Validation inputs and targets do not change between epochs: build them once.
val_inputs = _tofu_inputs(X_val)
y_val_t = torch.tensor(y_val).to(device).long()

for ep in range(epochs):
    ToFU_dim.train()
    classifier.train()

    # Separate name so `perm` keeps the train/val split permutation.
    epoch_perm = rng.permutation(len(X_train))
    X_epoch, y_epoch = [X_train[i] for i in epoch_perm], y_train[epoch_perm]
    correct = 0

    for i in range(0, len(X_train), batch_size):
        opt_tofu.zero_grad()
        opt_classifier.zero_grad()

        X_batch, y_batch = X_epoch[i:(i + batch_size)], y_epoch[i:(i + batch_size)]
        y_batch_t = torch.tensor(y_batch).to(device).long()
        vectorizations = _vectorize(_tofu_inputs(X_batch))
        v_batch = torch.cat(vectorizations, 1)

        # classifier returns raw logits; CrossEntropyLoss applies log-softmax itself.
        probs = classifier(v_batch)
        loss = criterion(probs, y_batch_t)

        predicted = probs.detach().argmax(1)
        correct += (predicted == y_batch_t).sum().item()

        loss.backward()
        opt_tofu.step()
        opt_classifier.step()

    # report accuracies (training accuracy accumulates while weights change)
    train_acc = correct / len(X_train)

    # validate
    ToFU_dim.eval()
    classifier.eval()
    with torch.no_grad():
        vectorizations_val = _vectorize(val_inputs)
        v_val = torch.cat(vectorizations_val, 1)
        outputs_val = classifier(v_val)

    predicted = outputs_val.argmax(1)
    total = len(y_val)
    correct = (predicted == y_val_t).sum().item()
    val_acc = correct / total
    print('Epoch ' + str(ep + 1) + ' Training Accuracy: ' + str(train_acc), ', Validation Accuracy: ' + str(val_acc))
    

Epoch 1 Training Accuracy: 0.1349206349206349 , Validation Accuracy: 0.3125
Epoch 2 Training Accuracy: 0.30158730158730157 , Validation Accuracy: 0.1875
Epoch 3 Training Accuracy: 0.2698412698412698 , Validation Accuracy: 0.15625
Epoch 4 Training Accuracy: 0.23809523809523808 , Validation Accuracy: 0.25
Epoch 5 Training Accuracy: 0.24603174603174602 , Validation Accuracy: 0.21875
Epoch 6 Training Accuracy: 0.2857142857142857 , Validation Accuracy: 0.5
Epoch 7 Training Accuracy: 0.38095238095238093 , Validation Accuracy: 0.28125
Epoch 8 Training Accuracy: 0.30952380952380953 , Validation Accuracy: 0.46875
Epoch 9 Training Accuracy: 0.42063492063492064 , Validation Accuracy: 0.4375
Epoch 10 Training Accuracy: 0.42857142857142855 , Validation Accuracy: 0.53125
Epoch 11 Training Accuracy: 0.49206349206349204 , Validation Accuracy: 0.5625
Epoch 12 Training Accuracy: 0.5079365079365079 , Validation Accuracy: 0.5625
Epoch 13 Training Accuracy: 0.5 , Validation Accuracy: 0.5
Epoch 14 Training 