In [1]:
import numpy as np
from tofu.modules import ToFULayer
import torch
from tofu.config import data_dir, model_dir
import pandas as pd
import h5py
import os

In [2]:
# utils

def diag_to_array(data):
    """Turn a nested ``{dim_key: {"0": pts, "1": pts, ...}}`` mapping into nested lists of arrays.

    The number of diagrams is read once from the ``"0"`` entry, and that same count of
    diagrams (keys ``"0"`` .. ``str(count - 1)``) is read for every dimension key, in the
    iteration order of ``data``.

    Parameters
    ----------
    data : mapping
        Mapping keyed by dimension strings; each value maps diagram-index strings to
        array-like point sets (presumably an h5py group of datasets — confirm with caller).

    Returns
    -------
    list of list of numpy.ndarray
        ``result[k][j]`` is diagram ``j`` of the ``k``-th dimension key, copied with ``np.array``.
    """
    n_diagrams = len(data["0"].keys())
    return [
        [np.array(data[dim_key][str(idx)]) for idx in range(n_diagrams)]
        for dim_key in data.keys()
    ]

def diag_to_dict(D):
    """Flatten a ``{group: {dim_key: {diag_index: points}}}`` mapping into one flat dict.

    Each top-level group ``f`` is converted with ``diag_to_array``; the per-dimension
    diagram lists it returns are stored under ``str(position) + "_" + f``, where
    ``position`` is the index in ``diag_to_array``'s output (not the original dim key).

    Returns
    -------
    dict
        e.g. ``{"0_geodesic": [arr, ...], "1_geodesic": [arr, ...], ...}``.
    """
    flat = {}
    for group_name in D.keys():
        per_dim = diag_to_array(D[group_name])
        for position, diagrams in enumerate(per_dim):
            flat[str(position) + "_" + group_name] = diagrams
    return flat

def pad(dgm, max_feats, fill_value=100):
    """Append sentinel rows to a 2-D persistence diagram so it has exactly ``max_feats`` rows.

    Parameters
    ----------
    dgm : array_like, shape (n_feats, n_cols)
        Diagram to pad; rows are appended at the end only.
    max_feats : int
        Target number of rows; must be >= ``len(dgm)``.
    fill_value : scalar, default 100
        Value written into every column of every appended row. When both columns are equal,
        a later ``death - birth`` transform gives these rows zero persistence.

    Returns
    -------
    numpy.ndarray, shape (max_feats, n_cols)

    Raises
    ------
    ValueError
        If ``max_feats`` is smaller than the number of rows in ``dgm`` (same exception type
        ``np.pad`` raised before, now with an explicit message).
    """
    dgm = np.asarray(dgm)
    n_feats = dgm.shape[0]
    n_missing = max_feats - n_feats
    if n_missing < 0:
        raise ValueError(
            f"pad: diagram has {n_feats} rows, more than max_feats={max_feats}"
        )
    # Only axis 0 is padded (after), so a single scalar constant fills all new cells.
    return np.pad(dgm, ((0, n_missing), (0, 0)), mode="constant", constant_values=fill_value)

def scale_persistence(dgm):
    """Normalise the second column of a (birth, persistence) diagram so it sums to one.

    Parameters
    ----------
    dgm : array_like, shape (n, >=2)
        Column 0 is kept as-is; column 1 is divided by its total. Extra columns are dropped.

    Returns
    -------
    numpy.ndarray, shape (n, 2)
        ``[birth, persistence / sum(persistence)]``. If the persistence column sums to zero
        (e.g. a diagram made only of zero-persistence padding rows), the scaled column is all
        zeros instead of the NaNs a 0/0 division would produce.
    """
    dgm = np.asarray(dgm)
    births = dgm[:, 0]
    pers = dgm[:, 1]
    total = np.sum(pers)
    if total == 0:
        # Avoid 0/0 -> NaN; multiplying by 0.0 keeps the float dtype the division would give.
        scaled_pers = pers * 0.0
    else:
        scaled_pers = pers / total
    return np.column_stack((births, scaled_pers))

In [3]:
# load data
# Files are opened through explicit paths instead of os.chdir(data_dir), so running this cell
# does not silently change the kernel's working directory for every later cell.
train_lab = pd.read_csv(os.path.join(data_dir, "train.csv"))
# The context manager closes the HDF5 file once the diagrams are read; diag_to_array copies
# each dataset with np.array, so train_diag holds in-memory arrays, not live HDF5 handles.
with h5py.File(os.path.join(data_dir, "train_diag.hdf5"), "r") as diag_file:
    train_diag = diag_to_dict(diag_file)

# data preprocessing
# '1_geodesic' = position-1 dimension of the 'geodesic' group (key format built by diag_to_dict).
dgms_raw = train_diag['1_geodesic']
y = np.array([int(lab) for lab in train_lab['part']])

# Pad every diagram to the same number of rows so they stack into one 3-D array.
max_feats = np.max([len(dgm) for dgm in dgms_raw])
dgms_padded = np.array([pad(dgm, max_feats=max_feats) for dgm in dgms_raw])
# Columns assumed to be (birth, death) -> (birth, persistence = death - birth), vectorised over
# the whole stack; padded rows (both columns equal to the sentinel) get zero persistence.
dgms_bp = np.stack(
    (dgms_padded[:, :, 0], dgms_padded[:, :, 1] - dgms_padded[:, :, 0]), axis=-1
)
# Normalise persistence within each diagram.
dgms_scaled = np.array([scale_persistence(dgm) for dgm in dgms_bp])

X, y = torch.tensor(dgms_scaled), torch.tensor(y)

In [4]:
# define model

class ShapeClassifier(torch.nn.Module):
    """ToFU layer followed by a SiLU MLP that maps a batch of persistence diagrams to logits.

    Parameters
    ----------
    n_dgms, n_feats, device, birth_lims, birth_death
        Forwarded unchanged to ``ToFULayer``. ``n_dgms`` is also the input width of ``lin1``,
        so the ToFU output is assumed to have ``n_dgms`` features per sample (TODO confirm
        against ToFULayer).
    lin_dim : int
        Width of ``lin1``-``lin3``; ``lin4``/``lin5`` use ``int(lin_dim/2)`` and ``lin6``
        outputs ``int(lin_dim/4)``.
    n_classes : int
        Number of output logits.

    Submodule names (``tofu``, ``lin1`` .. ``lin6``, ``logits``) and their creation order are
    kept fixed: saved ``state_dict`` keys depend on the names, and random initialisation
    consumes the RNG in creation order.
    """

    def __init__(self, n_dgms, n_feats, device, birth_lims, birth_death, lin_dim, n_classes):
        super().__init__()
        half_dim = int(lin_dim / 2)
        quarter_dim = int(lin_dim / 4)

        # Topological feature extractor (project-local layer).
        self.tofu = ToFULayer(n_dgms, n_feats, device, birth_lims, birth_death)

        # Three full-width layers, two half-width layers, one quarter-width layer.
        self.lin1 = torch.nn.Linear(n_dgms, lin_dim)
        self.lin2 = torch.nn.Linear(lin_dim, lin_dim)
        self.lin3 = torch.nn.Linear(lin_dim, lin_dim)
        self.lin4 = torch.nn.Linear(lin_dim, half_dim)
        self.lin5 = torch.nn.Linear(half_dim, half_dim)
        self.lin6 = torch.nn.Linear(half_dim, quarter_dim)

        # Output head: raw logits, no softmax (the notebook trains with CrossEntropyLoss).
        self.logits = torch.nn.Linear(quarter_dim, n_classes)

    def forward(self, dgm_batch):
        """Return unnormalised class logits for ``dgm_batch``."""
        features = self.tofu(dgm_batch)
        hidden_layers = (self.lin1, self.lin2, self.lin3, self.lin4, self.lin5, self.lin6)
        for layer in hidden_layers:
            features = torch.nn.functional.silu(layer(features))
        return self.logits(features)

In [11]:
# build model
# Seed torch's global RNGs so weight initialisation is reproducible; when the cells run in
# order this also fixes the torch.randperm split/shuffles in the training cell.
SEED = 0
torch.manual_seed(SEED)

n_dgms = 16          # number of ToFU outputs; also the input width of lin1
n_feats = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
birth_lims = [0, 2]
birth_death = False
lin_dim = 256        # width of the first hidden layers
n_classes = 4

mod = ShapeClassifier(n_dgms, n_feats, device, birth_lims, birth_death, lin_dim, n_classes).to(device).float()

# Set to True to resume from a saved checkpoint. NOTE(review): the save cell's provenance
# comment suggests that checkpoint was trained with n_dgms=64, which would not load into
# a model built with n_dgms=16 above — confirm before enabling.
LOAD_PRETRAINED = False
if LOAD_PRETRAINED:
    mod.load_state_dict(
        torch.load(os.path.join(model_dir, 'shape_segmentation_classfier.pt'), map_location=device)
    )

X, y = X.to(device).float(), y.to(device)
# CrossEntropyLoss needs integer class indices in [0, n_classes); fail early with a clear message
# rather than an index error (or a CUDA device-side assert) deep inside training.
assert 0 <= int(y.min()) and int(y.max()) < n_classes, "labels must lie in [0, n_classes)"

opt = torch.optim.Adam(mod.parameters(), lr=1e-3, weight_decay=1e-5)
#sched = torch.optim.lr_scheduler.ExponentialLR(opt,gamma = 0.9885)
criterion = torch.nn.CrossEntropyLoss()

In [12]:
# train model

n_epochs = 400
n_parts = 1       # number of random 50/50 train/validation splits
batch_size = 32

# NOTE(review): `mod` and `opt` are created once in the build cell and are NOT re-initialised
# per partition, so with n_parts > 1 every later split validates on samples the model has
# already trained on. Rebuild the model/optimiser inside this loop before raising n_parts.
for part in range(n_parts):
    n_samples = X.size(0)
    split_perm = torch.randperm(n_samples)
    n_train = n_samples // 2
    train_idx, val_idx = split_perm[:n_train], split_perm[n_train:]
    X_train, X_val, y_train, y_val = X[train_idx], X[val_idx], y[train_idx], y[val_idx]

    for epoch in range(n_epochs):
        # Back to training mode after the eval() call at the end of the previous epoch.
        mod.train()
        batch_perm = torch.randperm(X_train.size(0))  # reshuffle X_train every epoch
        correct = 0
        for start in range(0, X_train.size(0), batch_size):
            opt.zero_grad()

            indices = batch_perm[start:start + batch_size]
            batch_x, batch_y = X_train[indices], y_train[indices]

            outputs = mod(batch_x)
            loss = criterion(outputs, batch_y)

            # Running training accuracy, measured on each batch before its optimiser step.
            predicted = outputs.detach().argmax(dim=1)
            correct += (predicted == batch_y).sum().item()

            loss.backward()
            opt.step()
        #sched.step()

        # report accuracies
        train_acc = correct / X_train.shape[0]

        mod.eval()
        with torch.no_grad():
            outputs_val = mod(X_val)
        predicted = outputs_val.argmax(dim=1)
        total = y_val.size(0)
        correct = (predicted == y_val).sum().item()
        val_acc = correct / total
        print('Epoch ' + str(epoch + 1) + ' Training Accuracy: ' + str(train_acc),', Validation Accuracy: ' + str(val_acc))
        


Epoch 1 Training Accuracy: 0.43719298245614036 , Validation Accuracy: 0.49333333333333335
Epoch 2 Training Accuracy: 0.48596491228070177 , Validation Accuracy: 0.4849122807017544
Epoch 3 Training Accuracy: 0.5336842105263158 , Validation Accuracy: 0.5407017543859649
Epoch 4 Training Accuracy: 0.5954385964912281 , Validation Accuracy: 0.6298245614035087
Epoch 5 Training Accuracy: 0.628421052631579 , Validation Accuracy: 0.6140350877192983
Epoch 6 Training Accuracy: 0.6168421052631579 , Validation Accuracy: 0.6014035087719298
Epoch 7 Training Accuracy: 0.6435087719298246 , Validation Accuracy: 0.68
Epoch 8 Training Accuracy: 0.6817543859649123 , Validation Accuracy: 0.6326315789473684
Epoch 9 Training Accuracy: 0.68 , Validation Accuracy: 0.6992982456140351
Epoch 10 Training Accuracy: 0.6957894736842105 , Validation Accuracy: 0.7049122807017544
Epoch 11 Training Accuracy: 0.6957894736842105 , Validation Accuracy: 0.6989473684210527
Epoch 12 Training Accuracy: 0.6877192982456141 , Validat

KeyboardInterrupt: 

In [9]:
# Save through an explicit path instead of os.chdir(model_dir) so the kernel's working
# directory is left unchanged; create the directory if it does not exist yet.
# Running this cell overwrites any existing checkpoint with the current `mod` weights.
# Provenance note kept from the original: "212 epochs, 64, 1, torch.device("cuda" if
# torch.cuda.is_available() else "cpu"), [0,2], 256, 4" — values appear to follow the build
# cell's order (n_dgms, n_feats, device, birth_lims, lin_dim, n_classes). NOTE(review): the
# build cell now uses n_dgms=16, so this note may not describe what gets saved here; update it.
os.makedirs(model_dir, exist_ok=True)
torch.save(mod.state_dict(), os.path.join(model_dir, 'shape_segmentation_classfier.pt'))