In [1]:
import numpy as np  # type: ignore
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import gc
import multiprocessing
import os
from einops import rearrange  # type: ignore
from gtda.homology import WeakAlphaPersistence  # type: ignore
from gtda.plotting import plot_diagram  # type: ignore
from gdeep.create_data import generate_orbit
from gdeep.topology_layers import SmallSetTransformer, ISAB, PMA, SAB

In [8]:
try:
    assert os.path.isdir('./data/ORBIT5K')
except AssertionError:
    if not os.path.isdir('./data'):
        os.mkdir('./data')
    os.mkdir('./data/ORBIT5K')

# If `use_precomputed_dgms` is `False` the ORBIT5K dataset will
# be recomputed, otherwise the ORBIT5K dataset in the folder
# `data/ORBIT5K` will be used
use_precomputed_dgms = False

dgms_filename = os.path.join('data', 'ORBIT5K',
                             'alpha_persistence_diagrams.npy')
dgms_filename_validation = os.path.join('data', 'ORBIT5K',
                             'alpha_persistence_diagrams_validation.npy')

if use_precomputed_dgms:
    try:
        assert(os.path.isfile(dgms_filename))
    except AssertionError:
        print('File data/ORBIT5K/alpha_persistence_diagrams.npy',
              ' does not exist.')
    try:
        assert(os.path.isfile(dgms_filename_validation))
    except AssertionError:
        print('File data/ORBIT5K/alpha_persistence_diagrams.npy',
              ' does not exist.')


In [9]:

parameters = (2.5, 3.5, 4.0, 4.1, 4.3)  # different classes of orbits
homology_dimensions = (0, 1)

config = {
    'parameters': parameters,
    'num_classes': len(parameters),
    'num_orbits': 1_000,
    'num_pts_per_orbit': 1_000,
    'homology_dimensions': homology_dimensions,
    'num_homology_dimensions': len(homology_dimensions),
    'validation_percentage': 100,
}

if not use_precomputed_dgms:
    for dataset_type in ['train', 'validation']:
        # Generate dataset consisting of 5 different orbit types with
        # 1000 sampled data points each.
        # This is the dataset ORBIT5K used in the PersLay paper
        if dataset_type == 'train':
            num_orbits = config['num_orbits']
        else:
            num_orbits = int(config['num_orbits'] 
                            * config['validation_percentage'] / 100)
        x = np.zeros((
                        config['num_classes'],  # type: ignore
                        num_orbits,
                        config['num_pts_per_orbit'],
                        2
                    ))

        # generate dataset
        for cidx, p in enumerate(config['parameters']):  # type: ignore
            x[cidx, :, 0, :] = np.random.rand(num_orbits, 2)

            for i in range(1, config['num_pts_per_orbit']):  # type: ignore
                x_cur = x[cidx, :, i - 1, 0]
                y_cur = x[cidx, :, i - 1, 1]

                x[cidx, :, i, 0] = (x_cur + p * y_cur * (1. - y_cur)) % 1
                x_next = x[cidx, :, i, 0]
                x[cidx, :, i, 1] = (y_cur + p * x_next * (1. - x_next)) % 1


        """
        # old non-parallel version
        for cidx, p in enumerate(config['parameters']):  # type: ignore
            for i in range(config['num_orbits']):  # type: ignore
                x[cidx][i] = generate_orbit(
                    num_pts_per_orbit=config['num_pts_per_orbit'],  # type: ignore
                    parameter=p
                    )
        """

        assert(not np.allclose(x[0,0], x[0,1]))

        # compute weak alpha persistence
        wap = WeakAlphaPersistence(
                            homology_dimensions=config['homology_dimensions'],
                            n_jobs=multiprocessing.cpu_count()
                            )
        # c: class, o: orbit, p: point, d: dimension
        x_stack = rearrange(x, 'c o p d -> (c o) p d')  # stack classes
        diagrams = wap.fit_transform(x_stack)
        # shape: (num_classes * n_samples, n_features, 3)

        diagrams = rearrange(
                                diagrams,
                                '(c o) p d -> c o p d',
                                c=config['num_classes']  # type: ignore
                            )

        # plot sample persistence diagrams
        if(False):
            plot_diagram(diagrams[1, 2])
            plot_diagram(diagrams[2, 2])

        # save dataset
        if dataset_type == 'train':
            with open(dgms_filename_, 'wb') as f:
                np.save(f, diagrams)
        else:
            with open(dgms_filename_validation, 'wb') as f:
                np.save(f, diagrams)

In [4]:
# load dataset
with open(dgms_filename, 'rb') as f:
    x = np.load(f)

# c: class, o: orbit, p: point in persistence diagram,
# d: coordinates + homology dimension
x = rearrange(
                x,
                'c o p d -> (c o) p d',
                c=config['num_classes']  # type: ignore
            )
# convert homology dimension to one-hot encoding
x = np.concatenate(
    (
        x[:, :, :2],
        (np.eye(config['num_homology_dimensions'])
         [x[:, :, -1].astype(np.int32)]),
    ),
    axis=-1)
# convert from [orbit, sequence_length, feature] to
# [orbit, feature, sequence_length] to fit to the
# input_shape of `SmallSetTransformer`
#x = rearrange(x, 'o s f -> o f s')

# generate labels
y_list = []
for i in range(config['num_classes']):  # type: ignore
    y_list += [i] * config['num_orbits']  # type: ignore

y = np.array(y_list)


# load dataset to PyTorch dataloader

x_tensor = torch.Tensor(x)
y_tensor = torch.Tensor(y)

dataset = TensorDataset(x_tensor, y_tensor)
dataloader = DataLoader(dataset,
                        shuffle=True,
                        batch_size=64,
                        num_workers=6,
                        pin_memory=True
                        )

In [10]:
# load dataset
with open(dgms_filename_validation, 'rb') as f:
    x = np.load(f)

# c: class, o: orbit, p: point in persistence diagram,
# d: coordinates + homology dimension
x = rearrange(
                x,
                'c o p d -> (c o) p d',
                c=config['num_classes']  # type: ignore
            )
# convert homology dimension to one-hot encoding
x = np.concatenate(
    (
        x[:, :, :2],
        (np.eye(config['num_homology_dimensions'])
         [x[:, :, -1].astype(np.int32)]),
    ),
    axis=-1)
# convert from [orbit, sequence_length, feature] to
# [orbit, feature, sequence_length] to fit to the
# input_shape of `SmallSetTransformer`
#x = rearrange(x, 'o s f -> o f s')

# generate labels
y_list = []
for i in range(config['num_classes']):  # type: ignore
    y_list += [i] * config['num_orbits']  # type: ignore

y = np.array(y_list)


# load dataset to PyTorch dataloader

x_tensor = torch.Tensor(x)
y_tensor = torch.Tensor(y)

dataset = TensorDataset(x_tensor, y_tensor)
dataloader_validation = DataLoader(dataset,
                        shuffle=True,
                        batch_size=64,
                        num_workers=6,
                        pin_memory=True
                        )

In [13]:
# initialize SmallSetTransformer model
"""model = SmallSetTransformer(
                            dim_input=4,
                            dim_out=64,
                            num_heads=4,
                            out_features=config['num_classes']  # type: ignore
                            )
"""

class SetTransformer(nn.Module):
    def __init__(
        self,
        dim_input=3,
        num_outputs=1,
        dim_output=40,
        num_inds=32,
        dim_hidden=128,
        num_heads=4,
        ln=False,
    ):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
            ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
            ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln),
        )
        self.dec = nn.Sequential(
            nn.Dropout(),
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            nn.Dropout(),
            nn.Linear(dim_hidden, dim_output),
        )

    def forward(self, X):
        return self.dec(self.enc(X)).squeeze()

model = SetTransformer(dim_input=4, dim_output=5)

#print('model has', model.num_params(), 'trainable parameters.')

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")


In [17]:


def train(model, num_epochs: int = 10, lr: float = 1e-3,
          verbose: bool = False, compute_accuracy: bool = False):
    if use_cuda:
        model = nn.DataParallel(model)
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []

    # training loop
    for epoch in range(num_epochs):
        model.train()
        loss_per_epoch = 0
        for x_batch, y_batch in dataloader:
            # transfer to GPU
            if use_cuda:
                x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
            loss = criterion(model(x_batch), y_batch.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_per_epoch += loss.item()
            #print(loss.item())
        losses.append(loss_per_epoch)
        if verbose:
            print("epoch:", epoch, "loss:", loss_per_epoch)

        if compute_accuracy:
            # test accuracy
            correct = 0
            total = 0

            with torch.no_grad():
                for x_batch, y_batch in dataloader:
                    if use_cuda:
                            x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
                    outputs = model(x_batch).squeeze(1)
                    _, predictions = torch.max(outputs, 1)
                    total += y_batch.size(0)
                    correct += (predictions == y_batch).sum().item()

            print('Test accuracy of the network on the 5000 test diagrams: %8.2f %%' % (
                100 * correct / total))
            # validation accuracy
            correct = 0
            total = 0

            with torch.no_grad():
                for x_batch, y_batch in dataloader_validation:
                    if use_cuda:
                            x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
                    outputs = model(x_batch).squeeze(1)
                    _, predictions = torch.max(outputs, 1)
                    total += y_batch.size(0)
                    correct += (predictions == y_batch).sum().item()

            print('Validation accuracy of the network on the 5000 test diagrams: %8.2f %%' % (
                100 * correct / total))
    return losses

In [18]:
train(model, num_epochs=500, verbose=True, compute_accuracy=True)

epoch: 0 loss: 20.747889026999474
Test accuracy of the network on the 5000 test diagrams:    90.04 %
Validation accuracy of the network on the 5000 test diagrams:    89.62 %
epoch: 1 loss: 20.65145505219698
Test accuracy of the network on the 5000 test diagrams:    89.90 %
Validation accuracy of the network on the 5000 test diagrams:    88.82 %
epoch: 2 loss: 20.403223231434822
Test accuracy of the network on the 5000 test diagrams:    88.38 %
Validation accuracy of the network on the 5000 test diagrams:    87.36 %
epoch: 3 loss: 22.15877155959606
Test accuracy of the network on the 5000 test diagrams:    89.22 %
Validation accuracy of the network on the 5000 test diagrams:    88.60 %
epoch: 4 loss: 19.879891455173492
Test accuracy of the network on the 5000 test diagrams:    89.38 %
Validation accuracy of the network on the 5000 test diagrams:    89.02 %
epoch: 5 loss: 19.693836897611618
Test accuracy of the network on the 5000 test diagrams:    89.06 %
Validation accuracy of the netw

epoch: 47 loss: 18.850413415580988
Test accuracy of the network on the 5000 test diagrams:    89.82 %
Validation accuracy of the network on the 5000 test diagrams:    88.48 %
epoch: 48 loss: 18.658433236181736
Test accuracy of the network on the 5000 test diagrams:    90.68 %
Validation accuracy of the network on the 5000 test diagrams:    89.82 %
epoch: 49 loss: 19.136043176054955
Test accuracy of the network on the 5000 test diagrams:    87.42 %
Validation accuracy of the network on the 5000 test diagrams:    84.76 %
epoch: 50 loss: 19.72250735014677
Test accuracy of the network on the 5000 test diagrams:    90.38 %
Validation accuracy of the network on the 5000 test diagrams:    89.34 %
epoch: 51 loss: 19.370171554386616
Test accuracy of the network on the 5000 test diagrams:    90.36 %
Validation accuracy of the network on the 5000 test diagrams:    89.42 %
epoch: 52 loss: 19.215360939502716
Test accuracy of the network on the 5000 test diagrams:    90.50 %
Validation accuracy of t

epoch: 94 loss: 20.07527581602335
Test accuracy of the network on the 5000 test diagrams:    89.58 %
Validation accuracy of the network on the 5000 test diagrams:    89.26 %
epoch: 95 loss: 19.5769202709198
Test accuracy of the network on the 5000 test diagrams:    90.40 %
Validation accuracy of the network on the 5000 test diagrams:    90.32 %
epoch: 96 loss: 18.43622412532568
Test accuracy of the network on the 5000 test diagrams:    90.40 %
Validation accuracy of the network on the 5000 test diagrams:    89.66 %
epoch: 97 loss: 18.407309643924236
Test accuracy of the network on the 5000 test diagrams:    90.84 %
Validation accuracy of the network on the 5000 test diagrams:    90.12 %
epoch: 98 loss: 18.886694237589836
Test accuracy of the network on the 5000 test diagrams:    89.90 %
Validation accuracy of the network on the 5000 test diagrams:    89.46 %
epoch: 99 loss: 23.336259230971336
Test accuracy of the network on the 5000 test diagrams:    78.72 %
Validation accuracy of the 

epoch: 141 loss: 18.801198296248913
Test accuracy of the network on the 5000 test diagrams:    90.94 %
Validation accuracy of the network on the 5000 test diagrams:    89.56 %
epoch: 142 loss: 18.300055995583534
Test accuracy of the network on the 5000 test diagrams:    88.52 %
Validation accuracy of the network on the 5000 test diagrams:    87.72 %
epoch: 143 loss: 19.298768222332
Test accuracy of the network on the 5000 test diagrams:    90.56 %
Validation accuracy of the network on the 5000 test diagrams:    89.70 %
epoch: 144 loss: 19.15946315228939
Test accuracy of the network on the 5000 test diagrams:    90.06 %
Validation accuracy of the network on the 5000 test diagrams:    88.98 %
epoch: 145 loss: 18.553878232836723
Test accuracy of the network on the 5000 test diagrams:    90.92 %
Validation accuracy of the network on the 5000 test diagrams:    89.62 %
epoch: 146 loss: 19.482360631227493
Test accuracy of the network on the 5000 test diagrams:    90.34 %
Validation accuracy o

epoch: 188 loss: 18.194550059735775
Test accuracy of the network on the 5000 test diagrams:    90.78 %
Validation accuracy of the network on the 5000 test diagrams:    89.56 %
epoch: 189 loss: 19.049149490892887
Test accuracy of the network on the 5000 test diagrams:    90.70 %
Validation accuracy of the network on the 5000 test diagrams:    89.92 %
epoch: 190 loss: 19.516482021659613
Test accuracy of the network on the 5000 test diagrams:    86.62 %
Validation accuracy of the network on the 5000 test diagrams:    86.60 %
epoch: 191 loss: 19.907648339867592
Test accuracy of the network on the 5000 test diagrams:    89.92 %
Validation accuracy of the network on the 5000 test diagrams:    88.64 %
epoch: 192 loss: 18.591876413673162
Test accuracy of the network on the 5000 test diagrams:    90.26 %
Validation accuracy of the network on the 5000 test diagrams:    89.26 %
epoch: 193 loss: 19.77905486524105
Test accuracy of the network on the 5000 test diagrams:    90.56 %
Validation accurac

epoch: 235 loss: 18.939943946897984
Test accuracy of the network on the 5000 test diagrams:    90.56 %
Validation accuracy of the network on the 5000 test diagrams:    89.30 %
epoch: 236 loss: 17.941476605832577
Test accuracy of the network on the 5000 test diagrams:    91.16 %
Validation accuracy of the network on the 5000 test diagrams:    89.28 %
epoch: 237 loss: 18.40239456295967
Test accuracy of the network on the 5000 test diagrams:    90.66 %
Validation accuracy of the network on the 5000 test diagrams:    89.00 %
epoch: 238 loss: 18.302602469921112
Test accuracy of the network on the 5000 test diagrams:    91.08 %
Validation accuracy of the network on the 5000 test diagrams:    89.24 %
epoch: 239 loss: 19.17604151368141
Test accuracy of the network on the 5000 test diagrams:    90.42 %
Validation accuracy of the network on the 5000 test diagrams:    89.64 %
epoch: 240 loss: 18.392440997064114
Test accuracy of the network on the 5000 test diagrams:    90.34 %
Validation accuracy

epoch: 282 loss: 18.157440751791
Test accuracy of the network on the 5000 test diagrams:    91.54 %
Validation accuracy of the network on the 5000 test diagrams:    89.86 %
epoch: 283 loss: 18.450823694467545
Test accuracy of the network on the 5000 test diagrams:    90.50 %
Validation accuracy of the network on the 5000 test diagrams:    89.06 %
epoch: 284 loss: 18.43215185403824
Test accuracy of the network on the 5000 test diagrams:    91.06 %
Validation accuracy of the network on the 5000 test diagrams:    89.82 %
epoch: 285 loss: 19.10395280085504
Test accuracy of the network on the 5000 test diagrams:    90.72 %
Validation accuracy of the network on the 5000 test diagrams:    89.92 %
epoch: 286 loss: 18.73341454565525
Test accuracy of the network on the 5000 test diagrams:    90.52 %
Validation accuracy of the network on the 5000 test diagrams:    88.76 %
epoch: 287 loss: 18.207369446754456
Test accuracy of the network on the 5000 test diagrams:    91.06 %
Validation accuracy of 

epoch: 329 loss: 17.993426613509655
Test accuracy of the network on the 5000 test diagrams:    91.24 %
Validation accuracy of the network on the 5000 test diagrams:    90.02 %
epoch: 330 loss: 18.35368800163269
Test accuracy of the network on the 5000 test diagrams:    90.84 %
Validation accuracy of the network on the 5000 test diagrams:    89.68 %
epoch: 331 loss: 18.89242108911276
Test accuracy of the network on the 5000 test diagrams:    90.68 %
Validation accuracy of the network on the 5000 test diagrams:    89.42 %
epoch: 332 loss: 18.210813902318478
Test accuracy of the network on the 5000 test diagrams:    91.18 %
Validation accuracy of the network on the 5000 test diagrams:    89.72 %
epoch: 333 loss: 18.19181916117668
Test accuracy of the network on the 5000 test diagrams:    90.60 %
Validation accuracy of the network on the 5000 test diagrams:    89.40 %
epoch: 334 loss: 17.88386544585228
Test accuracy of the network on the 5000 test diagrams:    90.20 %
Validation accuracy o

epoch: 376 loss: 17.657096210867167
Test accuracy of the network on the 5000 test diagrams:    91.02 %
Validation accuracy of the network on the 5000 test diagrams:    89.46 %
epoch: 377 loss: 17.776838552206755
Test accuracy of the network on the 5000 test diagrams:    91.34 %
Validation accuracy of the network on the 5000 test diagrams:    89.66 %
epoch: 378 loss: 18.071385256014764
Test accuracy of the network on the 5000 test diagrams:    91.68 %
Validation accuracy of the network on the 5000 test diagrams:    89.90 %
epoch: 379 loss: 17.371812030673027
Test accuracy of the network on the 5000 test diagrams:    90.82 %
Validation accuracy of the network on the 5000 test diagrams:    89.66 %
epoch: 380 loss: 18.749385185539722
Test accuracy of the network on the 5000 test diagrams:    90.80 %
Validation accuracy of the network on the 5000 test diagrams:    89.40 %
epoch: 381 loss: 17.787607330828905
Test accuracy of the network on the 5000 test diagrams:    91.30 %
Validation accura

epoch: 423 loss: 17.8395382091403
Test accuracy of the network on the 5000 test diagrams:    90.62 %
Validation accuracy of the network on the 5000 test diagrams:    88.88 %
epoch: 424 loss: 18.322469972074032
Test accuracy of the network on the 5000 test diagrams:    91.22 %
Validation accuracy of the network on the 5000 test diagrams:    89.46 %
epoch: 425 loss: 17.808685697615147
Test accuracy of the network on the 5000 test diagrams:    91.06 %
Validation accuracy of the network on the 5000 test diagrams:    89.58 %
epoch: 426 loss: 17.39831894636154
Test accuracy of the network on the 5000 test diagrams:    90.84 %
Validation accuracy of the network on the 5000 test diagrams:    89.60 %
epoch: 427 loss: 17.154714815318584
Test accuracy of the network on the 5000 test diagrams:    90.84 %
Validation accuracy of the network on the 5000 test diagrams:    89.72 %
epoch: 428 loss: 17.549970537424088
Test accuracy of the network on the 5000 test diagrams:    90.30 %
Validation accuracy 

epoch: 470 loss: 17.453521476127207
Test accuracy of the network on the 5000 test diagrams:    91.74 %
Validation accuracy of the network on the 5000 test diagrams:    90.24 %
epoch: 471 loss: 19.274157628417015
Test accuracy of the network on the 5000 test diagrams:    82.36 %
Validation accuracy of the network on the 5000 test diagrams:    80.28 %
epoch: 472 loss: 20.165123268961906
Test accuracy of the network on the 5000 test diagrams:    90.86 %
Validation accuracy of the network on the 5000 test diagrams:    90.12 %
epoch: 473 loss: 17.33971268683672
Test accuracy of the network on the 5000 test diagrams:    91.12 %
Validation accuracy of the network on the 5000 test diagrams:    89.32 %
epoch: 474 loss: 17.664737723767757
Test accuracy of the network on the 5000 test diagrams:    90.36 %
Validation accuracy of the network on the 5000 test diagrams:    89.80 %
epoch: 475 loss: 17.348007008433342
Test accuracy of the network on the 5000 test diagrams:    90.62 %
Validation accurac

[20.747889026999474,
 20.65145505219698,
 20.403223231434822,
 22.15877155959606,
 19.879891455173492,
 19.693836897611618,
 19.675016447901726,
 19.363330580294132,
 19.44016182422638,
 18.9577427059412,
 19.897639617323875,
 20.536432534456253,
 19.505483224987984,
 19.21921593695879,
 20.123308464884758,
 19.498684279620647,
 19.730875335633755,
 19.30934865772724,
 18.915795024484396,
 20.589476868510246,
 20.251872450113297,
 19.54478994011879,
 21.2074057161808,
 19.810751482844353,
 19.04927945137024,
 19.38615544140339,
 19.275047704577446,
 19.48431959748268,
 19.619800619781017,
 19.519689172506332,
 19.426263887435198,
 20.351188600063324,
 19.604546420276165,
 20.618958920240402,
 22.48147886991501,
 25.309450939297676,
 20.21900037676096,
 19.108687222003937,
 19.23652637563646,
 19.176945492625237,
 19.64499705284834,
 18.67485513538122,
 18.681979089975357,
 19.130590957589447,
 19.176518999040127,
 19.152336910367012,
 19.355161927640438,
 18.850413415580988,
 18.658433

In [19]:
train(model, num_epochs=1000, verbose=True, compute_accuracy=True)

epoch: 0 loss: 18.68181872367859
Test accuracy of the network on the 5000 test diagrams:    90.62 %
Validation accuracy of the network on the 5000 test diagrams:    89.34 %
epoch: 1 loss: 18.330043144524097
Test accuracy of the network on the 5000 test diagrams:    90.82 %
Validation accuracy of the network on the 5000 test diagrams:    89.98 %
epoch: 2 loss: 19.28436778485775
Test accuracy of the network on the 5000 test diagrams:    91.14 %
Validation accuracy of the network on the 5000 test diagrams:    89.02 %
epoch: 3 loss: 18.061810217797756
Test accuracy of the network on the 5000 test diagrams:    91.18 %
Validation accuracy of the network on the 5000 test diagrams:    89.34 %
epoch: 4 loss: 17.87944643944502
Test accuracy of the network on the 5000 test diagrams:    90.80 %
Validation accuracy of the network on the 5000 test diagrams:    89.50 %
epoch: 5 loss: 17.6191086769104
Test accuracy of the network on the 5000 test diagrams:    91.14 %
Validation accuracy of the network

Test accuracy of the network on the 5000 test diagrams:    91.44 %
Validation accuracy of the network on the 5000 test diagrams:    88.92 %
epoch: 48 loss: 17.327013848349452
Test accuracy of the network on the 5000 test diagrams:    91.26 %
Validation accuracy of the network on the 5000 test diagrams:    89.70 %
epoch: 49 loss: 18.009168550372124
Test accuracy of the network on the 5000 test diagrams:    89.94 %
Validation accuracy of the network on the 5000 test diagrams:    88.84 %
epoch: 50 loss: 18.258909810334444
Test accuracy of the network on the 5000 test diagrams:    91.52 %
Validation accuracy of the network on the 5000 test diagrams:    89.88 %
epoch: 51 loss: 17.912146240472794
Test accuracy of the network on the 5000 test diagrams:    91.12 %
Validation accuracy of the network on the 5000 test diagrams:    89.30 %
epoch: 52 loss: 17.547801338136196
Test accuracy of the network on the 5000 test diagrams:    91.64 %
Validation accuracy of the network on the 5000 test diagra

Test accuracy of the network on the 5000 test diagrams:    91.46 %
Validation accuracy of the network on the 5000 test diagrams:    89.60 %
epoch: 95 loss: 16.80164085328579
Test accuracy of the network on the 5000 test diagrams:    91.24 %
Validation accuracy of the network on the 5000 test diagrams:    89.98 %
epoch: 96 loss: 17.071215510368347
Test accuracy of the network on the 5000 test diagrams:    91.62 %
Validation accuracy of the network on the 5000 test diagrams:    89.54 %
epoch: 97 loss: 17.112013585865498
Test accuracy of the network on the 5000 test diagrams:    91.24 %
Validation accuracy of the network on the 5000 test diagrams:    89.38 %
epoch: 98 loss: 17.316043724305928
Test accuracy of the network on the 5000 test diagrams:    91.52 %
Validation accuracy of the network on the 5000 test diagrams:    89.40 %
epoch: 99 loss: 16.949319828301668
Test accuracy of the network on the 5000 test diagrams:    91.56 %
Validation accuracy of the network on the 5000 test diagram

epoch: 141 loss: 17.62944955378771
Test accuracy of the network on the 5000 test diagrams:    91.80 %
Validation accuracy of the network on the 5000 test diagrams:    89.76 %
epoch: 142 loss: 17.813669480383396
Test accuracy of the network on the 5000 test diagrams:    90.90 %
Validation accuracy of the network on the 5000 test diagrams:    88.58 %
epoch: 143 loss: 17.121519830077887
Test accuracy of the network on the 5000 test diagrams:    91.32 %
Validation accuracy of the network on the 5000 test diagrams:    88.56 %
epoch: 144 loss: 17.9791152253747
Test accuracy of the network on the 5000 test diagrams:    91.26 %
Validation accuracy of the network on the 5000 test diagrams:    89.32 %
epoch: 145 loss: 17.554093247279525
Test accuracy of the network on the 5000 test diagrams:    91.86 %
Validation accuracy of the network on the 5000 test diagrams:    89.72 %
epoch: 146 loss: 16.764899250119925
Test accuracy of the network on the 5000 test diagrams:    91.38 %
Validation accuracy 

epoch: 188 loss: 17.612646736204624
Test accuracy of the network on the 5000 test diagrams:    89.40 %
Validation accuracy of the network on the 5000 test diagrams:    88.32 %
epoch: 189 loss: 18.259416840970516
Test accuracy of the network on the 5000 test diagrams:    90.94 %
Validation accuracy of the network on the 5000 test diagrams:    89.34 %
epoch: 190 loss: 18.112848110496998
Test accuracy of the network on the 5000 test diagrams:    90.82 %
Validation accuracy of the network on the 5000 test diagrams:    89.36 %
epoch: 191 loss: 17.29651341587305
Test accuracy of the network on the 5000 test diagrams:    91.58 %
Validation accuracy of the network on the 5000 test diagrams:    89.40 %
epoch: 192 loss: 16.791187673807144
Test accuracy of the network on the 5000 test diagrams:    91.18 %
Validation accuracy of the network on the 5000 test diagrams:    88.06 %
epoch: 193 loss: 16.857636395841837
Test accuracy of the network on the 5000 test diagrams:    91.58 %
Validation accurac

epoch: 235 loss: 16.719232968986034
Test accuracy of the network on the 5000 test diagrams:    91.12 %
Validation accuracy of the network on the 5000 test diagrams:    88.90 %
epoch: 236 loss: 16.629842095077038
Test accuracy of the network on the 5000 test diagrams:    91.66 %
Validation accuracy of the network on the 5000 test diagrams:    89.40 %
epoch: 237 loss: 16.484225124120712
Test accuracy of the network on the 5000 test diagrams:    91.78 %
Validation accuracy of the network on the 5000 test diagrams:    89.36 %
epoch: 238 loss: 16.43098515132442
Test accuracy of the network on the 5000 test diagrams:    91.44 %
Validation accuracy of the network on the 5000 test diagrams:    89.44 %
epoch: 239 loss: 17.091238752007484
Test accuracy of the network on the 5000 test diagrams:    91.38 %
Validation accuracy of the network on the 5000 test diagrams:    89.88 %
epoch: 240 loss: 17.343461267650127
Test accuracy of the network on the 5000 test diagrams:    91.12 %
Validation accurac

epoch: 282 loss: 16.057275729253888
Test accuracy of the network on the 5000 test diagrams:    91.60 %
Validation accuracy of the network on the 5000 test diagrams:    89.20 %
epoch: 283 loss: 16.65032623708248
Test accuracy of the network on the 5000 test diagrams:    91.76 %
Validation accuracy of the network on the 5000 test diagrams:    89.68 %
epoch: 284 loss: 17.152313247323036
Test accuracy of the network on the 5000 test diagrams:    91.12 %
Validation accuracy of the network on the 5000 test diagrams:    89.04 %
epoch: 285 loss: 16.93833038210869
Test accuracy of the network on the 5000 test diagrams:    91.30 %
Validation accuracy of the network on the 5000 test diagrams:    89.62 %
epoch: 286 loss: 25.409245274960995
Test accuracy of the network on the 5000 test diagrams:    87.46 %
Validation accuracy of the network on the 5000 test diagrams:    86.32 %
epoch: 287 loss: 21.268649391829967
Test accuracy of the network on the 5000 test diagrams:    91.36 %
Validation accuracy

KeyboardInterrupt: 

In [23]:
torch.save(model, 'PersFormer_ORBIT5K.pth')

In [24]:
torch.save(model.state_dict(), 'PersFormer_ORBIT5K_state_dict.pth')

number of trainable parameters: 291589


In [11]:
correct = 0
total = 0

with torch.no_grad():
    for x_batch, y_batch in dataloader_validation:
        if use_cuda:
                x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
        outputs = model(x_batch).squeeze(1)
        _, predictions = torch.max(outputs, 1)
        total += y_batch.size(0)
        correct += (predictions == y_batch).sum().item()

print('Accuracy of the network on the 5000 test diagrams: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 5000 test diagrams: 87 %


In [None]:
gc.collect()
torch.cuda.empty_cache()

In [6]:
model = model.cuda()
x_batch, y_batch = next(iter(dataloader))
x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
y_pred = model(x_batch)

In [14]:
model.requires_grad()

ModuleAttributeError: 'SmallSetTransformer' object has no attribute 'requires_grad'

In [11]:
#del x_batch
#del y_batch

In [None]:
print('Validation accuracy of the network on the 5000 test diagrams: %2d %%' % (87.4))