In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import random

import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
from torch.nn import CrossEntropyLoss

import gudhi as gd

from tqdm import tqdm

torch.set_printoptions(precision=2, sci_mode=False, linewidth=100)
np.set_printoptions(precision=2, suppress=True, linewidth=100)

In [2]:
def gudhi_toarray(diagrams, replace_inf=True):
    diagram = np.array([[birth, death, dim] for (dim, (birth, death)) in diagrams])
    if replace_inf==True:
        diagram = np.nan_to_num(diagram, posinf=-np.inf)
        diagram = np.nan_to_num(diagram, neginf=np.max(diagram))
    return diagram

def diagram_reshape(diagram):
    zero_idx = np.where(diagram[:,2]==0)
    one_idx = np.where(diagram[:,2]==1)
    return diagram[zero_idx], diagram[one_idx]

In [3]:
def diagram(image, device, sublevel=True):
    # get height and square image
    h = int(np.sqrt(image.shape[0]))
    image_sq = image.reshape((h,h))

    # create complex
    cmplx = gd.CubicalComplex(dimensions=(h, h), top_dimensional_cells=(sublevel*image))

    # get pairs of critical simplices
    cmplx.compute_persistence()
    critical_pairs = cmplx.cofaces_of_persistence_pairs()
    
    # get essential critical pixel
    bpx0_essential = critical_pairs[1][0][0] // h, critical_pairs[1][0][0] % h

    # get critical pixels corresponding to critical simplices
    try:
        bpx0 = [critical_pairs[0][0][i][0] for i in range(len(critical_pairs[0][0]))]
        dpx0 = [critical_pairs[0][0][i][1] for i in range(len(critical_pairs[0][0]))]
    except IndexError:
        bpx0 = []
        dpx0 = []
        
    try:
        bpx1 = [critical_pairs[0][1][i][0] for i in range(len(critical_pairs[0][1]))]
        dpx1 = [critical_pairs[0][1][i][1] for i in range(len(critical_pairs[0][1]))]
    except IndexError:
        bpx1 = []
        dpx1 = []
    

    flat_image = image_sq.flatten()
    pd0_essential = torch.tensor([[image_sq[bpx0_essential], torch.max(image)]])

    if (len(bpx0)!=0):
        pdb0 = flat_image[bpx0][:, None]
        pdd0 = flat_image[dpx0][:, None]
        pd0 = torch.Tensor(torch.hstack([pdb0, pdd0]))
        pd0 = torch.vstack([pd0, pd0_essential.to(device)])
    else:
        pd0 = pd0_essential

    if (len(bpx1)!=0):
        pdb1 = flat_image[bpx1][:, None]
        pdd1 = flat_image[dpx1][:, None]
        pd1 = torch.Tensor(torch.hstack([pdb1, pdd1]))
    else:
        pd1 = torch.zeros((1, 2))
    
    return pd0, pd1

# Seminar 12: Topological data analysis of digital images

## MNIST
- sublevel filtration
- directional filtration
- PHT w/ directional filtration
- PHT w/ convolutional filtraion (fully differentiable)

#### Torch Dataset and collate_fn

In [4]:
class PersistenceDataset(Dataset):
    
    def __init__(self, diagrams, y):
        super().__init__()
        
        # get diagrams as list of tensors
        D = torch.ones([len(diagrams), max(map(len, diagrams))+1, 3]) * torch.inf  
        for i, dgm in enumerate(diagrams):
            D[i,:len(dgm)] = dgm

        # cut to the largest diagram accross all dataset
        max_len = torch.argmax(D[:,:,0], axis=1).max()
        D = D[:,:max_len+1] # leave at least one inf value!
            
        self.D = D
        self.y = y
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.D[idx], self.y[idx]


def collate_fn(batch):

    # get len of a batch and len of each diagram in a batch
    n_batch = len(batch)
    d_lengths = [int(torch.argmax(D[:,0])) for D, y_ in batch]
    
    # set batch tensor to the max length of a diagram in a batch
    Ds = torch.ones([n_batch, max(d_lengths), 3]) * 0.
    D_masks = torch.zeros([n_batch, max(d_lengths)]).bool()
    ys = torch.zeros(n_batch).long()
    
    # populate diagrams, their masks, and targets
    for i, (D, y) in enumerate(batch):
        Ds[i][:d_lengths[i]] = D[:d_lengths[i]]
        D_masks[i][d_lengths[i]:] = True
        ys[i] = y
    
    return Ds, D_masks, ys

#### Model

In [5]:
class PersistentHomologyTransformer(nn.Module):
    
    def __init__(self, d_in=3, d_out=5, d_model=16, d_hidden=32, num_heads=2, num_layers=2, dropout=0.0):
        super().__init__()
        self.linear_in = nn.Linear(d_in, d_model)
        self.ln = nn.LayerNorm(d_in)
        el = nn.TransformerEncoderLayer(d_model, num_heads, d_hidden, dropout, batch_first=True, activation=F.gelu)
        self.encoder = nn.TransformerEncoder(el, num_layers)
        self.linear_out = nn.Linear(d_model, d_out)
        
    def _masked_mean(self, X, mask):
        X_masked = X * ~mask.unsqueeze(-1)
        n_masks = torch.sum(~mask, axis=1)
        X_masked_mean = torch.sum(X_masked, axis=1) / n_masks.unsqueeze(-1)
        return X_masked_mean
        
    def forward(self, X, mask):
        X = self.linear_in(self.ln(X))
        X = self.encoder(X, src_key_padding_mask=mask)
        X = self._masked_mean(X, mask)
        X = self.linear_out(X)
        return X

### Sublevel filtration

#### Data

In [6]:
transform = Compose([ToTensor(), Normalize(0.0, 1.0)])
dataset = MNIST("./data", train=False, download=True, transform=transform)
targets = dataset.targets

In [7]:
X_sublevel = []

for image, y_ in tqdm(dataset):

    # create cubical complex
    cmplx = gd.CubicalComplex(dimensions=image.shape, top_dimensional_cells=image.flatten())

    # get pairs of critical simplices
    cmplx.compute_persistence()
    X_sublevel.append(torch.tensor(gudhi_toarray(cmplx.persistence())))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:44<00:00, 225.64it/s]


#### Train

In [8]:
%%time
n_repeats = 1
n_epochs = 50
batch_size = 64
lr = 0.001

history = np.zeros((n_repeats, n_epochs, 3))
criterion = nn.CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset = PersistenceDataset(X_sublevel, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # model init
    model = PersistentHomologyTransformer(d_in=3, d_out=10)
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for X, mask, y in dataloader:
            loss_batch = criterion(model(X, mask), y)
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for X, mask, y in dataloader:
            y_hat = model(X, mask).argmax(dim=1)
            correct += int((y_hat == y).sum())
        accuracy_train = correct / len(dataloader.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

  0 Loss   Acc   
  0 2.2827 0.1506
  1 2.2377 0.1620
  2 2.2237 0.1648
  3 2.2201 0.1806
  4 2.2128 0.1930
  5 2.1988 0.1882
  6 2.1886 0.2053
  7 2.1648 0.2168
  8 2.1408 0.2170
  9 2.1201 0.2387
 10 2.0994 0.2490
 11 2.0845 0.2513
 12 2.0583 0.2311
 13 2.0512 0.2318
 14 2.0356 0.2754
 15 2.0379 0.2533
 16 2.0038 0.2649
 17 1.9894 0.2117
 18 1.9931 0.2856
 19 1.9753 0.2918
 20 1.9810 0.2770
 21 1.9670 0.2847
 22 1.9665 0.2961
 23 1.9583 0.2926
 24 1.9553 0.2732
 25 1.9384 0.2876
 26 1.9471 0.2998
 27 1.9404 0.3006
 28 1.9342 0.3056
 29 1.9261 0.3053
 30 1.9097 0.3030
 31 1.9049 0.3133
 32 1.9060 0.2927
 33 1.9173 0.2933
 34 1.8988 0.2969
 35 1.8961 0.2854
 36 1.9121 0.2976
 37 1.8964 0.3113
 38 1.8805 0.3169
 39 1.8783 0.3099
 40 1.8907 0.2672
 41 1.8852 0.3218
 42 1.8718 0.3135
 43 1.8929 0.3125
 44 1.8781 0.3061
 45 1.8811 0.2997
 46 1.8686 0.3105
 47 1.8640 0.3213
 48 1.8808 0.3014
 49 1.8715 0.3106

CPU times: user 4min 12s, sys: 28.6 s, total: 4min 41s
Wall time: 1min 11s


### Direction filter

In [9]:
dataset = MNIST("./data", train=False, download=True, transform=transform)
diagrams_all_dir = pickle.load(open("./data/MNIST_D_test_dir.pkl", "rb"))
targets = dataset.targets

In [10]:
direction = 0

diagrams_dir = []
for diagram_dir in diagrams_all_dir:
    dir_idx = diagram_dir[:,-1] == direction
    diagrams_dir.append(diagram_dir[dir_idx,:-2])

In [11]:
def collate_fn(batch):

    # get len of a batch and len of each diagram in a batch
    n_batch = len(batch)
    d_lengths = [int(torch.argmax(D[:,0])) for D, y_ in batch]
    
    # set batch tensor to the max length of a diagram in a batch
    Ds = torch.ones([n_batch, max(d_lengths), 3]) * 0.
    D_masks = torch.zeros([n_batch, max(d_lengths)]).bool()
    ys = torch.zeros(n_batch).long()
    
    # populate diagrams, their masks, and targets
    for i, (D, y) in enumerate(batch):
        Ds[i][:d_lengths[i]] = D[:d_lengths[i]]
        D_masks[i][d_lengths[i]:] = True
        ys[i] = y
        
    # masked normalize
    for i, (D, y) in enumerate(batch):
        Ds[i][~D_masks[i],0] = (Ds[i][~D_masks[i],0] - 1.0662154048380699) / 0.48181154844016033
        Ds[i][~D_masks[i],1] = (Ds[i][~D_masks[i],1] - 1.4032599645792931) / 0.3154062619965701
        Ds[i][~D_masks[i],2] = (Ds[i][~D_masks[i],2] - 0.7567565129537418) / 0.4290408990479055
    
    return Ds, D_masks, ys

#### Train

In [12]:
%%time
n_repeats = 1
n_epochs = 50
batch_size = 64
lr = 0.001

history = np.zeros((n_repeats, n_epochs, 3))
criterion = nn.CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset = PersistenceDataset(diagrams_dir, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # model init
    model = PersistentHomologyTransformer(d_in=3, d_out=10)
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for X, mask, y in dataloader:
            loss_batch = criterion(model(X, mask), y)
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for X, mask, y in dataloader:
            y_hat = model(X, mask).argmax(dim=1)
            correct += int((y_hat == y).sum())
        accuracy_train = correct / len(dataloader.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

  0 Loss   Acc   
  0 2.2007 0.2471
  1 1.8998 0.3478
  2 1.7431 0.4015
  3 1.6303 0.4446
  4 1.5513 0.4651
  5 1.5060 0.4704
  6 1.4723 0.4787
  7 1.4572 0.4694
  8 1.4448 0.4787
  9 1.4310 0.4706
 10 1.4211 0.4898
 11 1.4156 0.4921
 12 1.4048 0.4879
 13 1.3969 0.4950
 14 1.3953 0.5041
 15 1.3888 0.5079
 16 1.3853 0.5058
 17 1.3847 0.5034
 18 1.3745 0.5124
 19 1.3655 0.5122
 20 1.3661 0.5095
 21 1.3583 0.5098
 22 1.3603 0.5111
 23 1.3578 0.5080
 24 1.3595 0.5146
 25 1.3421 0.5164
 26 1.3488 0.5137
 27 1.3455 0.5053
 28 1.3453 0.5118
 29 1.3402 0.5223
 30 1.3386 0.5239
 31 1.3369 0.5199
 32 1.3343 0.5252
 33 1.3295 0.5246
 34 1.3314 0.5270
 35 1.3301 0.5247
 36 1.3211 0.5217
 37 1.3258 0.5165
 38 1.3187 0.5293
 39 1.3179 0.5271
 40 1.3198 0.5192
 41 1.3106 0.5205
 42 1.3141 0.5221
 43 1.3140 0.5236
 44 1.3053 0.5171
 45 1.3060 0.5191
 46 1.3056 0.5337
 47 1.3100 0.5324
 48 1.3029 0.5335
 49 1.2990 0.5228

CPU times: user 8min 10s, sys: 1min 10s, total: 9min 21s
Wall time: 2min 22s


## Persistence Homology Transform

### Direction filter

#### Dataset

In [13]:
class PersistenceTransformDataset(Dataset):
    
    def __init__(self, diagrams, y, idx=None, eps=None):
        super().__init__()
        
        # get diagrams as list of tensors
        D = torch.ones([len(diagrams), max(map(len, diagrams))+1, 4]) * torch.inf

        # select points according to eps and idx
        for i, dgm in enumerate(diagrams):

            # eps
            if eps is not None:
                eps_idx = (dgm[:,1] - dgm[:,0]) >= eps
                dgm = dgm[eps_idx]

            # idx
            if idx is not None:
                dgm_idx = torch.isin(dgm[:,-1], idx)
                dgm = dgm[dgm_idx]
                D[i,:len(dgm)] = dgm[:,:-1]
            else:
                D[i,:len(dgm)] = dgm[:,:-1]

        # cut to the largest diagram accross all dataset
        if idx is not None:
            max_len = torch.argmax(D[:,:,0], axis=1).max()
            D = D[:,:max_len+1] # leave at least one inf value!
            
        self.D = D
        self.y = y
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.D[idx], self.y[idx]
    
    
def collate_fn(batch):

    # get len of a batch and len of each diagram in a batch
    n_batch = len(batch)
    d_lengths = [int(torch.argmax(D[:,0])) for D, y_ in batch]
    
    # set batch tensor to the max length of a diagram in a batch
    Ds = torch.ones([n_batch, max(d_lengths), 4]) * 0.
    D_masks = torch.zeros([n_batch, max(d_lengths)]).bool()
    ys = torch.zeros(n_batch).long()
    
    # populate diagrams, their masks, and targets
    for i, (D, y) in enumerate(batch):
        Ds[i][:d_lengths[i]] = D[:d_lengths[i]]
        D_masks[i][d_lengths[i]:] = True
        ys[i] = y
        
    # masked normalize
    for i, (D, y) in enumerate(batch):
        Ds[i][~D_masks[i],0] = (Ds[i][~D_masks[i],0] - 1.0662154048380699) / 0.48181154844016033
        Ds[i][~D_masks[i],1] = (Ds[i][~D_masks[i],1] - 1.4032599645792931) / 0.3154062619965701
        Ds[i][~D_masks[i],2] = (Ds[i][~D_masks[i],2] - 0.7567565129537418) / 0.4290408990479055
        Ds[i][~D_masks[i],3] = (Ds[i][~D_masks[i],3] - 11.803894241177744) / 11.236356222975049
    
    return Ds, D_masks, ys

#### Train

In [14]:
%%time
n_repeats = 1
n_epochs = 50
batch_size = 64
lr = 0.001

history = np.zeros((n_repeats, n_epochs, 3))
criterion = nn.CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset = PersistenceTransformDataset(diagrams_all_dir, targets, idx=torch.tensor([0, 3, 6]), eps=0.0)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # model init
    model = PersistentHomologyTransformer(d_in=4, d_out=10) # increase input dim by 1
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for X, mask, y in dataloader:
            loss_batch = criterion(model(X, mask), y)
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for X, mask, y in dataloader:
            y_hat = model(X, mask).argmax(dim=1)
            correct += int((y_hat == y).sum())
        accuracy_train = correct / len(dataloader.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

  0 Loss   Acc   
  0 2.1943 0.2571
  1 1.7395 0.4293
  2 1.4065 0.5139
  3 1.2606 0.5548
  4 1.1725 0.5959
  5 1.1093 0.6001
  6 1.0515 0.6355
  7 1.0418 0.6268
  8 0.9912 0.6499
  9 0.9617 0.6485
 10 0.9524 0.6657
 11 0.9323 0.6792
 12 0.9368 0.6563
 13 0.9022 0.6548
 14 0.8900 0.6798
 15 0.8857 0.6697
 16 0.8697 0.6860
 17 0.8681 0.6848
 18 0.8609 0.6908
 19 0.8501 0.6695
 20 0.8458 0.7002
 21 0.8313 0.6997
 22 0.8379 0.7089
 23 0.8205 0.7099
 24 0.8093 0.6804
 25 0.8084 0.7175
 26 0.8138 0.7182
 27 0.7930 0.7234
 28 0.7891 0.6980
 29 0.7968 0.7319
 30 0.7726 0.7239
 31 0.7680 0.7162
 32 0.7766 0.7298
 33 0.7618 0.7079
 34 0.7531 0.7292
 35 0.7581 0.7303
 36 0.7482 0.7243
 37 0.7377 0.7222
 38 0.7307 0.7354
 39 0.7311 0.7313
 40 0.7369 0.7205
 41 0.7203 0.7508
 42 0.7145 0.7427
 43 0.7190 0.7611
 44 0.6955 0.7685
 45 0.6976 0.7578
 46 0.6999 0.7411
 47 0.6949 0.7573
 48 0.6891 0.7534
 49 0.6729 0.7589

CPU times: user 14min 31s, sys: 1min 28s, total: 16min
Wall time: 4min 4s


### Convolutional filter

#### Differentiability of persistent homology

In [15]:
dataset_diff = MNIST("./data", train=True, download=True, transform=transform)
dataloader_diff = DataLoader(dataset_diff, batch_size=2, shuffle=True)

In [16]:
images, _ = next(iter(dataloader_diff))
images.numpy()

array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],


       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]]], dtype=float32)

In [17]:
# create complex
cmplx = gd.CubicalComplex(dimensions=(28, 28), top_dimensional_cells=(images[0].flatten()))

# reduce boundary matrix
cmplx.compute_persistence()

In [18]:
# critical pixels
critical_pairs = cmplx.cofaces_of_persistence_pairs()
critical_pairs

[[array([], shape=(0, 2), dtype=int64),
  array([[158, 159],
         [317, 289],
         [325, 297],
         [381, 353],
         [489, 488],
         [214, 186],
         [298, 242],
         [372, 344],
         [429, 400],
         [436, 408],
         [487, 458],
         [490, 462],
         [463, 464],
         [546, 518],
         [602, 574],
         [687, 630]])],
 [array([783])]]

In [19]:
# persistence diagram
pd = cmplx.persistence()
pd

[(1, (0.0, 1.0)),
 (1, (0.7568627595901489, 0.9882352948188782)),
 (1, (0.95686274766922, 0.9960784316062927)),
 (1, (0.9882352948188782, 0.9960784316062927)),
 (1, (0.9882352948188782, 0.9960784316062927)),
 (1, (0.9921568632125854, 1.0)),
 (1, (0.9882352948188782, 0.9921568632125854)),
 (1, (0.9921568632125854, 0.9960784316062927)),
 (1, (0.9882352948188782, 0.9921568632125854)),
 (1, (0.9921568632125854, 0.9960784316062927)),
 (1, (0.9882352948188782, 0.9921568632125854)),
 (1, (0.9921568632125854, 0.9960784316062927)),
 (1, (0.9921568632125854, 0.9960784316062927)),
 (1, (0.9921568632125854, 0.9960784316062927)),
 (1, (0.9921568632125854, 0.9960784316062927)),
 (1, (0.9882352948188782, 0.9921568632125854)),
 (0, (0.0, inf))]

In [20]:
def persistence_diagram(image):

    h, w = image.shape
    img_flat = image.flatten()

    ccomplex = gd.CubicalComplex(
        dimensions = (h, w), 
        top_dimensional_cells=img_flat
    )
    
    # get pairs of critical simplices
    ccomplex.compute_persistence()
    critical_pairs = ccomplex.cofaces_of_persistence_pairs()

    # get essential critical pixels (never vanish)
    essential_features = critical_pairs[1][0]

    # 0-homology image critical pixels
    try:
        critical_pairs_0 = critical_pairs[0][0]
    except:
        critical_pairs_0 = np.empty((0, 2))
    critical_0_ver_ind = critical_pairs_0 // w
    critical_0_hor_ind = critical_pairs_0 % w
    critical_pixels_0 = np.stack([critical_0_ver_ind, critical_0_hor_ind], axis=2)

    # 0-homology essential pixels (ends with last added pixel)
    last_pixel = torch.argmax(image).item()
    essential_pixels_0 = np.array([[essential_features[0] // w, essential_features[0] % w], [last_pixel // w, last_pixel % 4]])[np.newaxis, ...]
    critical_pixels_0 = np.vstack([critical_pixels_0, essential_pixels_0])

    # 0-homology persistance diagram
    pd0 = image[critical_pixels_0[:, :, 0].flatten(), critical_pixels_0[:, :, 1].flatten()].reshape((critical_pixels_0.shape[0], 2))

    # 1-homology image critical pixels
    try:
        critical_pairs_1 = critical_pairs[0][1]
    except:
        critical_pairs_1 = np.empty((0, 2))
    critical_1_ver_ind = critical_pairs_1 // w
    critical_1_hor_ind = critical_pairs_1 % w
    critical_pixels_1 = np.stack([critical_1_ver_ind, critical_1_hor_ind], axis=2)

    # 1-homology persistance diagram
    pd1 = image[critical_pixels_1[:, :, 0].flatten(), critical_pixels_1[:, :, 1].flatten()].reshape((critical_pixels_1.shape[0], 2))

    return pd0, pd1

In [21]:
class ImagePersistence(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3))
        
    def forward(self, X):
        X_conv = self.conv(X)
        
        peristence_diagrams = []
        for i, x_conv in enumerate(X_conv):
            pd = persistence_diagram(x_conv[0])
            peristence_diagrams.append(pd)
        
        return peristence_diagrams

In [22]:
persistent_homology = ImagePersistence()
persistent_homology(images)

[(tensor([[0.10, 0.11],
          [0.16, 0.18],
          [0.17, 0.21],
          [0.06, 0.23],
          [0.18, 0.23],
          [0.08, 0.23],
          [0.18, 0.23],
          [0.02, 0.23]], grad_fn=<ViewBackward0>),
  tensor([[0.23, 0.25],
          [0.20, 0.27],
          [0.50, 0.50],
          [0.50, 0.51],
          [0.45, 0.52],
          [0.51, 0.53],
          [0.54, 0.55],
          [0.50, 0.56],
          [0.49, 0.58],
          [0.54, 0.67],
          [0.67, 0.76],
          [0.65, 0.78],
          [0.23, 0.80]], grad_fn=<ViewBackward0>)),
 (tensor([[ 0.10,  0.13],
          [ 0.13,  0.14],
          [ 0.07,  0.15],
          [ 0.15,  0.16],
          [ 0.13,  0.16],
          [ 0.17,  0.19],
          [ 0.16,  0.19],
          [ 0.15,  0.19],
          [ 0.16,  0.20],
          [ 0.05,  0.42],
          [-0.04,  0.23]], grad_fn=<ViewBackward0>),
  tensor([[0.20, 0.21],
          [0.19, 0.21],
          [0.21, 0.21],
          [0.19, 0.22],
          [0.23, 0.23],
        

#### Model

In [23]:
class ConvDiagram(nn.Module):
    def __init__(self, device):
        super(ConvDiagram, self).__init__()
        self.device = device
        
    def forward(self, x):
        diagrams = []
        for i in range(x.shape[0]):
            res = diagram(x[i].flatten(), self.device)
            for j in range(len(res)):
                diagrams.append(torch.concatenate([res[j], torch.Tensor([[j, i] for _ in range(res[j].shape[0])]).to(self.device)], axis=1))
        diagrams = torch.concatenate(diagrams)
        return diagrams


class Transformer(torch.nn.Module):
    def __init__(self, n_in, n_hidden, n_out, seq_size=1024, nhead=2, num_layers=2, dim_feedforward=16):
        super(Transformer, self).__init__()
        self.embeddings = nn.Linear(n_in, n_hidden)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=n_hidden, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer=self.encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(seq_size, n_out)

    def forward(self, X):
        X = self.embeddings(X)
        X = self.transformer(X)
        X = X.mean(dim=-1)
        X = self.classifier(X)
        X = X.softmax(dim=-1)
        return X


class TopologicalConvTransformer(nn.Module):
    def __init__(self, n_in, n_conv, max_sequence, n_diag, n_hidden, n_out, nhead=2, num_layers=2, dim_feedforward=16, device='cuda'):
        super(TopologicalConvTransformer, self).__init__()
        
        self.max_sequence = max_sequence
        self.conv1 = nn.Conv2d(n_in, n_conv, 3)
        self.conv2 = nn.Conv2d(n_conv, n_conv, 3)
        self.conv3 = nn.Conv2d(n_conv, n_conv, 3)
        self.bn1 = nn.BatchNorm2d(n_conv)
        self.bn2 = nn.BatchNorm2d(n_conv)
        self.bn3 = nn.BatchNorm2d(n_conv)
        self.diagram = ConvDiagram(device)
        self.transformer = Transformer(n_diag, n_hidden, n_out, max_sequence, nhead, num_layers, dim_feedforward)

    def forward(self, xs):
        result = []
        for i in range(xs.shape[0]):
            x = xs[i][None, :, :] / 256
            x = self.bn1(F.gelu(self.conv1(x)))
            #x = self.bn2(F.gelu(self.conv2(x)))
            #x = self.bn3(F.gelu(self.conv3(x)))
            x = self.diagram(x)
            if x.shape[0] > self.max_sequence:
                x = x[:self.max_sequence]
            x = F.pad(x, (0, 0, 0, self.max_sequence - x.shape[0]), "constant", 0)
            x = self.transformer(x)
            result.append(x[None, :])
        result = torch.concatenate(result, axis=0)
        return result

In [24]:
kwargs = {"n_in": 1,
 "n_conv": 1,
 "max_sequence": 64,
 "n_diag": 4,
 "n_hidden": 32, "n_out": 10, "nhead": 2, "num_layers": 2, "dim_feedforward": 16, "device": "cpu"}

In [25]:
model = TopologicalConvTransformer(**kwargs)
model(images)[0]

tensor([0.10, 0.10, 0.10, 0.11, 0.10, 0.10, 0.10, 0.11, 0.09, 0.09], grad_fn=<SelectBackward0>)

In [26]:
%%time
n_repeats = 1
n_epochs = 20
batch_size = 64
lr = 0.002

history = np.zeros((n_repeats, n_epochs, 3))
criterion = CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset_conv = MNIST("./data", train=False, download=True, transform=transform)
    dataloader_conv = DataLoader(dataset_conv, batch_size=batch_size, shuffle=True)
    
    # model init
    model = TopologicalConvTransformer(**kwargs)
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for batch in dataloader_conv:
            loss_batch = criterion(model(batch[0]), batch[1])
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for batch in dataloader_conv:
            y_hat = model(batch[0]).argmax(dim=1)
            correct += int((y_hat == batch[1]).sum())
        accuracy_train = correct / len(dataloader_conv.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

  0 Loss   Acc   


KeyboardInterrupt: 