In [1]:
# %%
from typing import List
import numpy as np  # type: ignore
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import multiprocessing
import os
from einops import rearrange  # type: ignore
from gtda.homology import WeakAlphaPersistence  # type: ignore
from gtda.plotting import plot_diagram  # type: ignore
from gdeep.topology_layers import ISAB, PMA

# %%
try:
    assert os.path.isdir('./data/ORBIT5K')
except AssertionError:
    if not os.path.isdir('./data'):
        os.mkdir('./data')
    os.mkdir('./data/ORBIT5K')

# If `use_precomputed_dgms` is `False` the ORBIT5K dataset will
# be recomputed, otherwise the ORBIT5K dataset in the folder
# `data/ORBIT5K` will be used
use_precomputed_dgms = False

dgms_filename = os.path.join('data', 'ORBIT5K',
                             'alpha_persistence_diagrams.npy')
dgms_filename_validation = os.path.join('data', 'ORBIT5K',
                                        'alpha_persistence_diagrams_' +
                                        'validation.npy')

if use_precomputed_dgms:
    try:
        assert(os.path.isfile(dgms_filename))
    except AssertionError:
        print('File data/ORBIT5K/alpha_persistence_diagrams.npy',
              ' does not exist.')
    try:
        assert(os.path.isfile(dgms_filename_validation))
    except AssertionError:
        print('File data/ORBIT5K/alpha_persistence_diagrams.npy',
              ' does not exist.')

# %%
# Create ORBIT5K dataset like in the PersLay paper
parameters = (2.5, 3.5, 4.0, 4.1, 4.3)  # different classes of orbits
homology_dimensions = (0, 1)

config = {
    'parameters': parameters,
    'num_classes': len(parameters),
    'num_orbits': 1_000,  # number of orbits per class
    'num_pts_per_orbit': 1_000,
    'homology_dimensions': homology_dimensions,
    'num_homology_dimensions': len(homology_dimensions),
    'validation_percentage': 100,  # size of validation dataset relative
    # to training
}

if not use_precomputed_dgms:
    for dataset_type in ['train', 'validation']:
        # Generate dataset consisting of 5 different orbit types with
        # 1000 sampled data points each.
        # This is the dataset ORBIT5K used in the PersLay paper
        if dataset_type == 'train':
            num_orbits = config['num_orbits']
        else:
            num_orbits = int(config['num_orbits']  # type: ignore
                             * config['validation_percentage'] / 100)
        x = np.zeros((
                        config['num_classes'],  # type: ignore
                        num_orbits,
                        config['num_pts_per_orbit'],
                        2
                    ))

        # generate dataset
        for cidx, p in enumerate(config['parameters']):  # type: ignore
            x[cidx, :, 0, :] = np.random.rand(num_orbits, 2)

            for i in range(1, config['num_pts_per_orbit']):  # type: ignore
                x_cur = x[cidx, :, i - 1, 0]
                y_cur = x[cidx, :, i - 1, 1]

                x[cidx, :, i, 0] = (x_cur + p * y_cur * (1. - y_cur)) % 1
                x_next = x[cidx, :, i, 0]
                x[cidx, :, i, 1] = (y_cur + p * x_next * (1. - x_next)) % 1

        """
        # old non-parallel version
        for cidx, p in enumerate(config['parameters']):  # type: ignore
            for i in range(config['num_orbits']):  # type: ignore
                x[cidx][i] = generate_orbit(
                    num_pts_per_orbit=config['num_pts_per_orbit'],
                    parameter=p
                    )
        """

        assert(not np.allclose(x[0, 0], x[0, 1]))

        # compute weak alpha persistence
        wap = WeakAlphaPersistence(
                            homology_dimensions=config['homology_dimensions'],
                            n_jobs=multiprocessing.cpu_count()
                            )
        # c: class, o: orbit, p: point, d: dimension
        x_stack = rearrange(x, 'c o p d -> (c o) p d')  # stack classes
        diagrams = wap.fit_transform(x_stack)
        # shape: (num_classes * n_samples, n_features, 3)

        # combine class and orbit dimensions
        diagrams = rearrange(
                                diagrams,
                                '(c o) p d -> c o p d',
                                c=config['num_classes']  # type: ignore
                            )

        # plot sample persistence diagrams for debugging
        if(False):
            plot_diagram(diagrams[1, 2])
            plot_diagram(diagrams[2, 2])

        # save dataset
        if dataset_type == 'train':
            with open(dgms_filename, 'wb') as f:
                np.save(f, diagrams)
        else:
            with open(dgms_filename_validation, 'wb') as f:
                np.save(f, diagrams)
# %%
# load dataset
for dataset_type in ['train', 'validation']:

    with open(dgms_filename, 'rb') as f:
        x = np.load(f)

    # c: class, o: orbit, p: point in persistence diagram,
    # d: coordinates + homology dimension
    x = rearrange(
                    x,
                    'c o p d -> (c o) p d',
                    c=config['num_classes']  # type: ignore
                )
    # convert homology dimension to one-hot encoding
    x = np.concatenate(
        (
            x[:, :, :2],
            (np.eye(config['num_homology_dimensions'])
             [x[:, :, -1].astype(np.int32)]),
        ),
        axis=-1)
    # convert from [orbit, sequence_length, feature] to
    # [orbit, feature, sequence_length] to fit to the
    # input_shape of `SmallSetTransformer`
    # x = rearrange(x, 'o s f -> o f s')

    # generate labels
    y_list = []
    for i in range(config['num_classes']):  # type: ignore
        y_list += [i] * config['num_orbits']  # type: ignore

    y = np.array(y_list)

    # load dataset to PyTorch dataloader

    x_tensor = torch.Tensor(x)
    y_tensor = torch.Tensor(y)

    dataset = TensorDataset(x_tensor, y_tensor)
    if dataset_type == 'train':
        dataloader = DataLoader(dataset,
                                shuffle=True,
                                batch_size=2 ** 6,
                                num_workers=6)
    else:
        dataloader_validation = DataLoader(dataset,
                                           batch_size=2 ** 6,
                                           num_workers=6)

# %%


# initialize SetTransformer model
class SetTransformer(nn.Module):
    """ Vanilla SetTransformer from
    https://github.com/juho-lee/set_transformer/blob/master/main_pointcloud.py
    """
    def __init__(
        self,
        dim_input=3,  # dimension of input data for each element in the set
        num_outputs=1,
        dim_output=40,  # number of classes
        num_inds=32,  # number of induced points, see  Set Transformer paper
        dim_hidden=128,
        num_heads=4,
        ln=False,  # use layer norm
    ):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
            ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
            ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln),
        )
        self.dec = nn.Sequential(
            nn.Dropout(),
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            nn.Dropout(),
            nn.Linear(dim_hidden, dim_output),
        )

    def forward(self, input):
        return self.dec(self.enc(input)).squeeze()


model = SetTransformer(dim_input=4, dim_output=5)


def num_params(model: nn.Module) -> int:
    return sum([parameter.nelement() for parameter in model.parameters()])


print('model has', num_params(model), 'trainable parameters.')  # type: ignore

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")


# %%
def train(model, num_epochs: int = 10, lr: float = 1e-3,
          verbose: bool = False) -> List[float]:
    """Custom training loop for Set Transformer on the dataset ``
    Args:
        model (nn.Module): Set Transformer model to be trained
        num_epochs (int, optional): Number of training epochs. Defaults to 10.
        lr (float, optional): Learning rate for training. Defaults to 1e-3.
        verbose (bool, optional): Print training loss, training accuracy and
            validation if set to True. Defaults to False.
    Returns:
        List[float]: List of training losses
    """
    if use_cuda:
        model = nn.DataParallel(model)
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses: List[float] = []
    # training loop
    for epoch in range(num_epochs):
        model.train()
        loss_per_epoch = 0
        for x_batch, y_batch in dataloader:
            # transfer to GPU
            if use_cuda:
                x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
            loss = criterion(model(x_batch), y_batch.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_per_epoch += loss.item()
        losses.append(loss_per_epoch)
        if verbose:
            print("epoch:", epoch, "loss:", loss_per_epoch)
            compute_accuracy(model, 'test')
            compute_accuracy(model, 'validation')
    return losses


def compute_accuracy(model, type: str = 'test') -> None:
    correct = 0
    total = 0
    if type == 'test':
        dl = dataloader
    else:
        dl = dataloader_validation

    with torch.no_grad():
        for x_batch, y_batch in dl:
            if use_cuda:
                x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
            outputs = model(x_batch).squeeze(1)
            _, predictions = torch.max(outputs, 1)
            total += y_batch.size(0)
            correct += (predictions == y_batch).sum().item()

    print(type.capitalize(),
          'accuracy of the network on the', total,
          'diagrams: %8.2f %%' % (100 * correct / total)
          )


# %%
train(model, num_epochs=500, verbose=True)

model has 291589 trainable parameters.
epoch: 0 loss: 128.43465423583984
Test accuracy of the network on the 5000 diagrams:    19.54 %
Validation accuracy of the network on the 5000 diagrams:    19.82 %
epoch: 1 loss: 127.78605687618256
Test accuracy of the network on the 5000 diagrams:    19.48 %
Validation accuracy of the network on the 5000 diagrams:    20.54 %
epoch: 2 loss: 127.54263913631439
Test accuracy of the network on the 5000 diagrams:    20.34 %
Validation accuracy of the network on the 5000 diagrams:    19.36 %
epoch: 3 loss: 127.70188856124878
Test accuracy of the network on the 5000 diagrams:    19.80 %
Validation accuracy of the network on the 5000 diagrams:    18.98 %
epoch: 4 loss: 127.40705442428589
Test accuracy of the network on the 5000 diagrams:    19.92 %
Validation accuracy of the network on the 5000 diagrams:    18.54 %
epoch: 5 loss: 127.43623185157776
Test accuracy of the network on the 5000 diagrams:    20.66 %
Validation accuracy of the network on the 500

Validation accuracy of the network on the 5000 diagrams:    31.18 %
epoch: 50 loss: 112.37828552722931
Test accuracy of the network on the 5000 diagrams:    33.20 %
Validation accuracy of the network on the 5000 diagrams:    32.60 %
epoch: 51 loss: 111.09530258178711
Test accuracy of the network on the 5000 diagrams:    41.90 %
Validation accuracy of the network on the 5000 diagrams:    41.70 %
epoch: 52 loss: 100.15738624334335
Test accuracy of the network on the 5000 diagrams:    51.46 %
Validation accuracy of the network on the 5000 diagrams:    51.88 %
epoch: 53 loss: 93.2746929526329
Test accuracy of the network on the 5000 diagrams:    44.22 %
Validation accuracy of the network on the 5000 diagrams:    45.50 %
epoch: 54 loss: 83.0915042757988
Test accuracy of the network on the 5000 diagrams:    64.26 %
Validation accuracy of the network on the 5000 diagrams:    64.32 %
epoch: 55 loss: 73.79250580072403
Test accuracy of the network on the 5000 diagrams:    63.78 %
Validation accu

Validation accuracy of the network on the 5000 diagrams:    87.86 %
epoch: 100 loss: 22.610743328928947
Test accuracy of the network on the 5000 diagrams:    87.82 %
Validation accuracy of the network on the 5000 diagrams:    87.30 %
epoch: 101 loss: 25.855116590857506
Test accuracy of the network on the 5000 diagrams:    89.30 %
Validation accuracy of the network on the 5000 diagrams:    89.46 %
epoch: 102 loss: 24.03989939391613
Test accuracy of the network on the 5000 diagrams:    88.30 %
Validation accuracy of the network on the 5000 diagrams:    88.60 %
epoch: 103 loss: 23.64241924136877
Test accuracy of the network on the 5000 diagrams:    89.78 %
Validation accuracy of the network on the 5000 diagrams:    89.40 %
epoch: 104 loss: 23.48386822640896
Test accuracy of the network on the 5000 diagrams:    88.56 %
Validation accuracy of the network on the 5000 diagrams:    88.14 %
epoch: 105 loss: 22.80709721893072
Test accuracy of the network on the 5000 diagrams:    88.86 %
Validati

Test accuracy of the network on the 5000 diagrams:    88.86 %
Validation accuracy of the network on the 5000 diagrams:    89.16 %
epoch: 150 loss: 22.03894168883562
Test accuracy of the network on the 5000 diagrams:    84.84 %
Validation accuracy of the network on the 5000 diagrams:    84.50 %
epoch: 151 loss: 23.845813617110252
Test accuracy of the network on the 5000 diagrams:    87.80 %
Validation accuracy of the network on the 5000 diagrams:    87.74 %
epoch: 152 loss: 21.73010703921318
Test accuracy of the network on the 5000 diagrams:    89.36 %
Validation accuracy of the network on the 5000 diagrams:    89.32 %
epoch: 153 loss: 22.234844408929348
Test accuracy of the network on the 5000 diagrams:    90.08 %
Validation accuracy of the network on the 5000 diagrams:    90.20 %
epoch: 154 loss: 20.836495995521545
Test accuracy of the network on the 5000 diagrams:    89.32 %
Validation accuracy of the network on the 5000 diagrams:    89.42 %
epoch: 155 loss: 22.36608485877514
Test ac

epoch: 199 loss: 21.43756289035082
Test accuracy of the network on the 5000 diagrams:    89.60 %
Validation accuracy of the network on the 5000 diagrams:    89.68 %
epoch: 200 loss: 20.88115794956684
Test accuracy of the network on the 5000 diagrams:    90.06 %
Validation accuracy of the network on the 5000 diagrams:    90.02 %
epoch: 201 loss: 20.24783929437399
Test accuracy of the network on the 5000 diagrams:    89.86 %
Validation accuracy of the network on the 5000 diagrams:    89.84 %
epoch: 202 loss: 21.424878649413586
Test accuracy of the network on the 5000 diagrams:    86.98 %
Validation accuracy of the network on the 5000 diagrams:    87.06 %
epoch: 203 loss: 21.2998256534338
Test accuracy of the network on the 5000 diagrams:    90.52 %
Validation accuracy of the network on the 5000 diagrams:    90.50 %
epoch: 204 loss: 20.096501775085926
Test accuracy of the network on the 5000 diagrams:    89.66 %
Validation accuracy of the network on the 5000 diagrams:    89.76 %
epoch: 20

Validation accuracy of the network on the 5000 diagrams:    89.68 %
epoch: 249 loss: 20.829761400818825
Test accuracy of the network on the 5000 diagrams:    89.70 %
Validation accuracy of the network on the 5000 diagrams:    89.32 %
epoch: 250 loss: 20.51490778476
Test accuracy of the network on the 5000 diagrams:    90.66 %
Validation accuracy of the network on the 5000 diagrams:    90.34 %
epoch: 251 loss: 20.657420001924038
Test accuracy of the network on the 5000 diagrams:    89.78 %
Validation accuracy of the network on the 5000 diagrams:    90.02 %
epoch: 252 loss: 21.335706397891045
Test accuracy of the network on the 5000 diagrams:    90.00 %
Validation accuracy of the network on the 5000 diagrams:    90.02 %
epoch: 253 loss: 20.29035521298647
Test accuracy of the network on the 5000 diagrams:    89.92 %
Validation accuracy of the network on the 5000 diagrams:    90.18 %
epoch: 254 loss: 19.475899402052164
Test accuracy of the network on the 5000 diagrams:    90.14 %
Validatio

Test accuracy of the network on the 5000 diagrams:    90.64 %
Validation accuracy of the network on the 5000 diagrams:    90.54 %
epoch: 299 loss: 19.99121691286564
Test accuracy of the network on the 5000 diagrams:    90.44 %
Validation accuracy of the network on the 5000 diagrams:    90.08 %
epoch: 300 loss: 19.91942585259676
Test accuracy of the network on the 5000 diagrams:    90.32 %
Validation accuracy of the network on the 5000 diagrams:    90.42 %
epoch: 301 loss: 19.31378647685051
Test accuracy of the network on the 5000 diagrams:    89.00 %
Validation accuracy of the network on the 5000 diagrams:    89.16 %
epoch: 302 loss: 21.127029789146036
Test accuracy of the network on the 5000 diagrams:    90.44 %
Validation accuracy of the network on the 5000 diagrams:    90.56 %
epoch: 303 loss: 20.165587715804577
Test accuracy of the network on the 5000 diagrams:    89.28 %
Validation accuracy of the network on the 5000 diagrams:    89.72 %
epoch: 304 loss: 21.890318870544434
Test ac

epoch: 348 loss: 20.180256314575672
Test accuracy of the network on the 5000 diagrams:    90.22 %
Validation accuracy of the network on the 5000 diagrams:    90.16 %
epoch: 349 loss: 19.756914764642715
Test accuracy of the network on the 5000 diagrams:    90.66 %
Validation accuracy of the network on the 5000 diagrams:    90.36 %
epoch: 350 loss: 19.208304330706596
Test accuracy of the network on the 5000 diagrams:    90.70 %
Validation accuracy of the network on the 5000 diagrams:    90.42 %
epoch: 351 loss: 19.684471115469933
Test accuracy of the network on the 5000 diagrams:    90.62 %
Validation accuracy of the network on the 5000 diagrams:    90.16 %
epoch: 352 loss: 19.339708156883717
Test accuracy of the network on the 5000 diagrams:    90.08 %
Validation accuracy of the network on the 5000 diagrams:    90.42 %
epoch: 353 loss: 18.996000107377768
Test accuracy of the network on the 5000 diagrams:    90.50 %
Validation accuracy of the network on the 5000 diagrams:    90.58 %
epoc

Validation accuracy of the network on the 5000 diagrams:    90.16 %
epoch: 398 loss: 21.106209307909012
Test accuracy of the network on the 5000 diagrams:    90.32 %
Validation accuracy of the network on the 5000 diagrams:    90.44 %
epoch: 399 loss: 19.584492050111294
Test accuracy of the network on the 5000 diagrams:    90.34 %
Validation accuracy of the network on the 5000 diagrams:    90.56 %
epoch: 400 loss: 19.170945189893246
Test accuracy of the network on the 5000 diagrams:    90.84 %
Validation accuracy of the network on the 5000 diagrams:    90.62 %
epoch: 401 loss: 19.472616732120514
Test accuracy of the network on the 5000 diagrams:    89.76 %
Validation accuracy of the network on the 5000 diagrams:    89.96 %
epoch: 402 loss: 18.644321374595165
Test accuracy of the network on the 5000 diagrams:    90.10 %
Validation accuracy of the network on the 5000 diagrams:    89.96 %
epoch: 403 loss: 19.35397431999445
Test accuracy of the network on the 5000 diagrams:    90.62 %
Valid

Test accuracy of the network on the 5000 diagrams:    89.44 %
Validation accuracy of the network on the 5000 diagrams:    89.54 %
epoch: 448 loss: 20.022407859563828
Test accuracy of the network on the 5000 diagrams:    89.58 %
Validation accuracy of the network on the 5000 diagrams:    89.86 %
epoch: 449 loss: 19.672982282936573
Test accuracy of the network on the 5000 diagrams:    90.62 %
Validation accuracy of the network on the 5000 diagrams:    90.44 %
epoch: 450 loss: 23.54917772114277
Test accuracy of the network on the 5000 diagrams:    90.02 %
Validation accuracy of the network on the 5000 diagrams:    90.08 %
epoch: 451 loss: 20.446417346596718
Test accuracy of the network on the 5000 diagrams:    90.42 %
Validation accuracy of the network on the 5000 diagrams:    90.32 %
epoch: 452 loss: 20.00468249619007
Test accuracy of the network on the 5000 diagrams:    90.44 %
Validation accuracy of the network on the 5000 diagrams:    90.18 %
epoch: 453 loss: 19.84957218170166
Test ac

epoch: 497 loss: 19.485870704054832
Test accuracy of the network on the 5000 diagrams:    90.50 %
Validation accuracy of the network on the 5000 diagrams:    90.78 %
epoch: 498 loss: 19.158110089600086
Test accuracy of the network on the 5000 diagrams:    90.54 %
Validation accuracy of the network on the 5000 diagrams:    90.14 %
epoch: 499 loss: 19.535780772566795
Test accuracy of the network on the 5000 diagrams:    90.74 %
Validation accuracy of the network on the 5000 diagrams:    90.38 %


[128.43465423583984,
 127.78605687618256,
 127.54263913631439,
 127.70188856124878,
 127.40705442428589,
 127.43623185157776,
 127.59978806972504,
 127.40756225585938,
 127.34827077388763,
 127.44339752197266,
 127.45875382423401,
 127.43387293815613,
 127.29627668857574,
 127.31724560260773,
 127.31289553642273,
 127.2538241147995,
 127.20631670951843,
 127.32400703430176,
 127.3118587732315,
 127.35974156856537,
 127.26410210132599,
 127.2933304309845,
 127.24492263793945,
 127.19129824638367,
 127.23867499828339,
 127.18598818778992,
 127.14183461666107,
 127.16810071468353,
 127.19790613651276,
 126.83528518676758,
 126.64188396930695,
 127.37898421287537,
 126.6927570104599,
 127.03667950630188,
 127.1336499452591,
 127.02402079105377,
 126.46275770664215,
 124.64741706848145,
 122.9453774690628,
 125.90857565402985,
 124.28945708274841,
 121.88105547428131,
 125.07325732707977,
 123.54615831375122,
 121.55071663856506,
 117.56634104251862,
 116.44311809539795,
 117.97319543361664

In [2]:
import matplotlib.pyplot as plt
