In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import random

import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
from torch.nn import CrossEntropyLoss

import gudhi as gd

from tqdm import tqdm

torch.set_printoptions(precision=2, sci_mode=False, linewidth=100)
np.set_printoptions(precision=2, suppress=True, linewidth=100)

In [None]:
def gudhi_toarray(diagrams, replace_inf=True):
    """Convert a gudhi-style persistence list into an ``(n, 3)`` float array.

    Parameters
    ----------
    diagrams : iterable of ``(dim, (birth, death))``
        Persistence pairs in the layout iterated over below.
    replace_inf : bool, default True
        If truthy, every ``+inf`` birth/death value is replaced by the largest
        finite birth/death value found in the diagram.

    Returns
    -------
    numpy.ndarray
        Rows are ``[birth, death, dim]``. An empty input yields shape ``(0, 3)``.
    """
    diagram = np.array(
        [[birth, death, dim] for (dim, (birth, death)) in diagrams], dtype=float
    ).reshape(-1, 3)

    if replace_inf:
        # Only look at the birth/death columns: the homology-dimension column
        # must not leak into the replacement value.
        values = diagram[:, :2]  # view, so the assignment below edits `diagram`
        finite = values[np.isfinite(values)]
        if finite.size:
            values[np.isposinf(values)] = finite.max()
    return diagram

def diagram_reshape(diagram):
    """Split a diagram by the value in its third column.

    Returns a pair ``(rows with column 2 == 0, rows with column 2 == 1)``;
    rows holding any other value in that column are dropped.
    """
    dims = diagram[:, 2]
    return diagram[dims == 0], diagram[dims == 1]

In [None]:
def diagram(image, device, sublevel=True):
    """Persistence diagrams of a flattened square image, as tensors.

    The critical cells are found with gudhi on a detached copy of the pixel
    values; the returned diagram values are then gathered from ``image``
    itself, so they stay connected to its autograd graph.

    Parameters
    ----------
    image : torch.Tensor
        1-D tensor with ``h * h`` entries (reshaped to ``(h, h)`` below).
    device : torch.device or str
        Device on which both returned tensors are placed.
    sublevel : bool, default True
        ``True`` filters by increasing pixel value. ``False`` filters ``-image``
        (superlevel sets); the returned values are still read from ``image``.

    Returns
    -------
    (pd0, pd1) : tuple of torch.Tensor
        ``pd0`` has one ``[birth, death]`` row per finite 0-dimensional pair
        followed by one row for the essential component, whose death is set
        to the last value entering the filtration (``max`` of the image for
        sublevel, ``min`` for superlevel).
        ``pd1`` has one row per 1-dimensional pair, or a single ``[0, 0]`` row
        when there is none.
    """
    # get height and square image (reshape raises if the length is not h * h)
    h = int(np.sqrt(image.shape[0]))
    image_sq = image.reshape((h, h))
    flat_image = image_sq.flatten()

    # gudhi only needs plain numbers: detach so tensors that require grad can
    # be converted, and negate for a superlevel filtration
    filtration = flat_image if sublevel else -flat_image
    cells = filtration.detach().cpu().numpy()
    cmplx = gd.CubicalComplex(dimensions=(h, h), top_dimensional_cells=cells)

    # get pairs of critical cells
    cmplx.compute_persistence()
    critical_pairs = cmplx.cofaces_of_persistence_pairs()
    regular_pairs, essential_cells = critical_pairs[0], critical_pairs[1]

    def finite_points(dim):
        """(n, 2) tensor of [birth, death] values for `dim`, or None if absent."""
        if dim >= len(regular_pairs) or len(regular_pairs[dim]) == 0:
            return None
        idx = torch.as_tensor(
            np.asarray(regular_pairs[dim]).reshape(-1, 2),
            dtype=torch.long,
            device=flat_image.device,
        )
        return flat_image[idx].to(device)

    # essential 0-dimensional class; torch.stack (unlike torch.tensor on a list
    # of tensors) keeps the gradient path to the image
    essential_birth = flat_image[int(essential_cells[0][0])]
    essential_death = torch.max(flat_image) if sublevel else torch.min(flat_image)
    pd0_essential = torch.stack([essential_birth, essential_death]).unsqueeze(0).to(device)

    pd0_finite = finite_points(0)
    if pd0_finite is None:
        pd0 = pd0_essential
    else:
        pd0 = torch.vstack([pd0_finite, pd0_essential])

    pd1 = finite_points(1)
    if pd1 is None:
        # placeholder row so callers always receive a non-empty tensor
        pd1 = torch.zeros((1, 2), dtype=flat_image.dtype, device=device)

    return pd0, pd1

# Seminar 12: Topological data analysis of digital images

#### Torch Dataset and collate_fn

In [None]:
class PersistenceDataset(Dataset):
    """Dataset of 3-column persistence diagrams padded with ``inf`` rows.

    Every diagram is copied into a common tensor ``self.D`` of shape
    ``(n_samples, longest_diagram + 1, 3)``. Unused rows keep the value
    ``inf``, so each sample ends with at least one ``inf`` row that marks
    where the real points stop.
    """

    def __init__(self, diagrams, y):
        super().__init__()

        # one extra row so that even the longest diagram is inf-terminated
        n_rows = max(len(dgm) for dgm in diagrams) + 1
        padded = torch.full([len(diagrams), n_rows, 3], torch.inf)
        for row, dgm in enumerate(diagrams):
            padded[row, :len(dgm)] = dgm

        # cut to the largest diagram across the whole dataset; argmax on the
        # first column is used to locate the first inf row of each sample
        longest = int(padded[:, :, 0].argmax(dim=1).max())
        padded = padded[:, :longest + 1]  # leave at least one inf value!

        self.D = padded
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.D[idx], self.y[idx]


def collate_fn(batch):
    """Collate inf-padded 3-column diagrams into a zero-padded batch.

    Parameters
    ----------
    batch : list of ``(D, y)``
        ``D`` is a diagram whose first column's argmax marks the end of the
        real points (the first ``inf`` row); ``y`` is the target.

    Returns
    -------
    Ds : float tensor ``(n_batch, max_len, 3)``, zero-padded diagrams
    D_masks : bool tensor ``(n_batch, max_len)``, ``True`` on padding rows
    ys : long tensor ``(n_batch,)``
    """
    # number of real points per diagram and the widest diagram in the batch
    lengths = [int(D[:, 0].argmax()) for D, _ in batch]
    width = max(lengths)

    Ds = torch.zeros(len(batch), width, 3)
    ys = torch.zeros(len(batch), dtype=torch.long)
    for row, ((D, y), n_points) in enumerate(zip(batch, lengths)):
        Ds[row, :n_points] = D[:n_points]
        ys[row] = y

    # a position is padding when its index is not below the diagram length
    D_masks = torch.arange(width).unsqueeze(0) >= torch.tensor(lengths).unsqueeze(1)

    return Ds, D_masks, ys

#### Model

In [None]:
class PersistentHomologyTransformer(nn.Module):
    """Transformer encoder over the points of a persistence diagram.

    Pipeline: LayerNorm over the input features -> linear embedding ->
    ``num_layers`` encoder layers (GELU, batch-first) -> mean over the
    non-padded points -> linear head producing ``d_out`` values per diagram.
    """

    def __init__(self, d_in=3, d_out=5, d_model=16, d_hidden=32, num_heads=2, num_layers=2, dropout=0.0):
        super().__init__()
        self.linear_in = nn.Linear(d_in, d_model)
        self.ln = nn.LayerNorm(d_in)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_hidden,
            dropout=dropout,
            activation=F.gelu,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.linear_out = nn.Linear(d_model, d_out)

    def _masked_mean(self, X, mask):
        """Average ``X`` over dim 1 using only positions where ``mask`` is False."""
        keep = (~mask).unsqueeze(-1)
        # multiplying by the boolean mask zeroes the padded positions
        return (X * keep).sum(dim=1) / keep.sum(dim=1)

    def forward(self, X, mask):
        """Map diagrams ``X`` and padding mask ``mask`` (True = padding) to ``d_out`` outputs."""
        embedded = self.linear_in(self.ln(X))
        encoded = self.encoder(embedded, src_key_padding_mask=mask)
        pooled = self._masked_mean(encoded, mask)
        return self.linear_out(pooled)

### Sublevel filtration

#### Data

In [None]:
# MNIST split selected with train=False, converted to tensors; the labels are
# kept separately as `targets`
transform = Compose([ToTensor(), Normalize(mean=0.0, std=1.0)])
dataset = MNIST(root="./data", train=False, download=True, transform=transform)
targets = dataset.targets

In [None]:
# One persistence diagram per image, stored as a tensor built from the
# array returned by gudhi_toarray
X_sublevel = []

for image, y_ in tqdm(dataset):

    # cubical complex whose top-dimensional cells carry the pixel values
    cmplx = gd.CubicalComplex(top_dimensional_cells=image.flatten(), dimensions=image.shape)
    cmplx.compute_persistence()

    # convert the gudhi persistence list to an array, then to a tensor
    dgm_array = gudhi_toarray(cmplx.persistence())
    X_sublevel.append(torch.tensor(dgm_array))

#### Train

In [None]:
%%time
n_repeats = 1
n_epochs = 50
batch_size = 64
lr = 0.001

history = np.zeros((n_repeats, n_epochs, 3))
criterion = nn.CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset = PersistenceDataset(X_sublevel, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # model init
    model = PersistentHomologyTransformer(d_in=3, d_out=10)
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for X, mask, y in dataloader:
            loss_batch = criterion(model(X, mask), y)
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for X, mask, y in dataloader:
            y_hat = model(X, mask).argmax(dim=1)
            correct += int((y_hat == y).sum())
        accuracy_train = correct / len(dataloader.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

### Direction filter

In [None]:
dataset = MNIST("./data", train=False, download=True, transform=transform)
# NOTE: unpickling can execute arbitrary code; only load a file you produced
# yourself or otherwise trust. The `with` block closes the file handle.
with open("./data/MNIST_D_test_dir.pkl", "rb") as diagrams_file:
    diagrams_all_dir = pickle.load(diagrams_file)
targets = dataset.targets

In [None]:
# Keep, for every sample, only the rows whose last column equals `direction`,
# and drop the last two columns of those rows.
direction = 0

diagrams_dir = [
    diagram_dir[diagram_dir[:, -1] == direction, :-2]
    for diagram_dir in diagrams_all_dir
]

In [None]:
def collate_fn(batch):

    # get len of a batch and len of each diagram in a batch
    n_batch = len(batch)
    d_lengths = [int(torch.argmax(D[:,0])) for D, y_ in batch]
    
    # set batch tensor to the max length of a diagram in a batch
    Ds = torch.ones([n_batch, max(d_lengths), 3]) * 0.
    D_masks = torch.zeros([n_batch, max(d_lengths)]).bool()
    ys = torch.zeros(n_batch).long()
    
    # populate diagrams, their masks, and targets
    for i, (D, y) in enumerate(batch):
        Ds[i][:d_lengths[i]] = D[:d_lengths[i]]
        D_masks[i][d_lengths[i]:] = True
        ys[i] = y
        
    # masked normalize
    for i, (D, y) in enumerate(batch):
        Ds[i][~D_masks[i],0] = (Ds[i][~D_masks[i],0] - 1.0662154048380699) / 0.48181154844016033
        Ds[i][~D_masks[i],1] = (Ds[i][~D_masks[i],1] - 1.4032599645792931) / 0.3154062619965701
        Ds[i][~D_masks[i],2] = (Ds[i][~D_masks[i],2] - 0.7567565129537418) / 0.4290408990479055
    
    return Ds, D_masks, ys

#### Train

In [None]:
%%time
n_repeats = 1
n_epochs = 50
batch_size = 64
lr = 0.001

history = np.zeros((n_repeats, n_epochs, 3))
criterion = nn.CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset = PersistenceDataset(diagrams_dir, targets)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # model init
    model = PersistentHomologyTransformer(d_in=3, d_out=10)
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for X, mask, y in dataloader:
            loss_batch = criterion(model(X, mask), y)
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for X, mask, y in dataloader:
            y_hat = model(X, mask).argmax(dim=1)
            correct += int((y_hat == y).sum())
        accuracy_train = correct / len(dataloader.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

## Persistent Homology Transform

### Direction filter

#### Dataset

In [None]:
class PersistenceTransformDataset(Dataset):
    """Dataset of 4-column diagrams with optional filtering, padded with ``inf``.

    Each input diagram is filtered and then stored without its last column.

    Parameters
    ----------
    diagrams : sequence of 2-D tensors
        Column 0 and 1 are used as birth and death for the ``eps`` filter;
        the last column is matched against ``idx`` and then dropped.
    y : targets, indexed in parallel with ``diagrams``
    idx : tensor or None
        If given, keep only rows whose last column is contained in ``idx``.
    eps : float or None
        If given, keep only rows with ``death - birth >= eps``.
    """

    def __init__(self, diagrams, y, idx=None, eps=None):
        super().__init__()

        # one extra row so that even the longest diagram is inf-terminated
        n_rows = max(len(dgm) for dgm in diagrams) + 1
        padded = torch.full([len(diagrams), n_rows, 4], torch.inf)

        for row, dgm in enumerate(diagrams):
            # eps: drop points with persistence below the threshold
            if eps is not None:
                dgm = dgm[(dgm[:, 1] - dgm[:, 0]) >= eps]

            # idx: keep only the requested values of the last column
            if idx is not None:
                dgm = dgm[torch.isin(dgm[:, -1], idx)]

            padded[row, :len(dgm)] = dgm[:, :-1]

        # cut to the largest diagram across the whole dataset
        # (only done when an idx selection was applied)
        if idx is not None:
            longest = int(padded[:, :, 0].argmax(dim=1).max())
            padded = padded[:, :longest + 1]  # leave at least one inf value!

        self.D = padded
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.D[idx], self.y[idx]
    
    
def collate_fn(batch):

    # get len of a batch and len of each diagram in a batch
    n_batch = len(batch)
    d_lengths = [int(torch.argmax(D[:,0])) for D, y_ in batch]
    
    # set batch tensor to the max length of a diagram in a batch
    Ds = torch.ones([n_batch, max(d_lengths), 4]) * 0.
    D_masks = torch.zeros([n_batch, max(d_lengths)]).bool()
    ys = torch.zeros(n_batch).long()
    
    # populate diagrams, their masks, and targets
    for i, (D, y) in enumerate(batch):
        Ds[i][:d_lengths[i]] = D[:d_lengths[i]]
        D_masks[i][d_lengths[i]:] = True
        ys[i] = y
        
    # masked normalize
    for i, (D, y) in enumerate(batch):
        Ds[i][~D_masks[i],0] = (Ds[i][~D_masks[i],0] - 1.0662154048380699) / 0.48181154844016033
        Ds[i][~D_masks[i],1] = (Ds[i][~D_masks[i],1] - 1.4032599645792931) / 0.3154062619965701
        Ds[i][~D_masks[i],2] = (Ds[i][~D_masks[i],2] - 0.7567565129537418) / 0.4290408990479055
        Ds[i][~D_masks[i],3] = (Ds[i][~D_masks[i],3] - 11.803894241177744) / 11.236356222975049
    
    return Ds, D_masks, ys

#### Train

In [None]:
%%time
n_repeats = 1
n_epochs = 50
batch_size = 64
lr = 0.001

history = np.zeros((n_repeats, n_epochs, 3))
criterion = nn.CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset = PersistenceTransformDataset(diagrams_all_dir, targets, idx=torch.tensor([0, 3, 6]), eps=0.0)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # model init
    model = PersistentHomologyTransformer(d_in=4, d_out=10) # increase input dim by 1
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for X, mask, y in dataloader:
            loss_batch = criterion(model(X, mask), y)
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for X, mask, y in dataloader:
            y_hat = model(X, mask).argmax(dim=1)
            correct += int((y_hat == y).sum())
        accuracy_train = correct / len(dataloader.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

### Convolutional filter

#### Differentiability of persistent homology

In [None]:
# MNIST split selected with train=True, served in shuffled batches of two images
dataset_diff = MNIST(root="./data", train=True, download=True, transform=transform)
dataloader_diff = DataLoader(dataset=dataset_diff, batch_size=2, shuffle=True)

In [None]:
# take the first batch from the shuffling loader; the labels are discarded
images, _ = next(iter(dataloader_diff))
# display the batch as a NumPy array
images.numpy()

In [None]:
# create a 28x28 cubical complex whose top-dimensional cells are the
# flattened values of the first image in the batch
cmplx = gd.CubicalComplex(dimensions=(28, 28), top_dimensional_cells=(images[0].flatten()))

# compute persistence; must run before the pairs and the diagram are queried below
cmplx.compute_persistence()

In [None]:
# critical pixels: cells that create/destroy each persistence pair
# (presumably flat pixel indices, finite pairs first and essential ones second,
# grouped by homology dimension - see the gudhi documentation to confirm)
critical_pairs = cmplx.cofaces_of_persistence_pairs()
critical_pairs

In [None]:
# persistence diagram of the complex built above, displayed as returned by gudhi
# NOTE(review): `pd` is the conventional alias for pandas; this notebook does
# not import pandas, but the name is easy to misread
pd = cmplx.persistence()
pd

In [None]:
def persistence_diagram(image):
    """Differentiable 0- and 1-dimensional persistence diagrams of a 2-D image.

    The critical cells are computed by gudhi on a detached copy of the pixel
    values; the diagram values are then gathered from ``image`` itself, so
    gradients can flow from the diagrams back to the pixels.

    Parameters
    ----------
    image : torch.Tensor
        2-D tensor of shape ``(h, w)``.

    Returns
    -------
    (pd0, pd1) : tuple of torch.Tensor
        ``pd0`` has shape ``(n0 + 1, 2)``: the finite 0-dimensional pairs
        followed by the essential component, paired with the pixel of maximal
        value (the last pixel to enter the filtration).
        ``pd1`` has shape ``(n1, 2)`` and is empty when there are no
        1-dimensional pairs.
    """
    h, w = image.shape
    img_flat = image.flatten()

    # gudhi consumes plain numbers; detach so that tensors which require grad
    # can be converted
    ccomplex = gd.CubicalComplex(
        dimensions=(h, w),
        top_dimensional_cells=img_flat.detach().cpu().numpy(),
    )

    # get pairs of critical cells
    ccomplex.compute_persistence()
    critical_pairs = ccomplex.cofaces_of_persistence_pairs()
    regular_pairs, essential_features = critical_pairs[0], critical_pairs[1]

    def critical_cells(dim):
        """(n, 2) integer array of (birth, death) flat pixel indices for `dim`."""
        if dim < len(regular_pairs):
            return np.asarray(regular_pairs[dim], dtype=np.int64).reshape(-1, 2)
        # integer dtype so an empty result is still a valid index array
        return np.empty((0, 2), dtype=np.int64)

    def values_at(cells):
        """Gather pixel values at flat indices, keeping the autograd link."""
        index = torch.as_tensor(cells, dtype=torch.long, device=img_flat.device)
        return img_flat[index]

    # 0-homology essential pair: born at the essential cell, "dies" at the
    # last added pixel. Working with flat indices removes the need for the
    # row/column split (the original used `% 4` instead of `% w` there).
    last_pixel = torch.argmax(image).item()
    essential_pair_0 = np.array([[essential_features[0][0], last_pixel]], dtype=np.int64)

    pd0 = values_at(np.vstack([critical_cells(0), essential_pair_0]))
    pd1 = values_at(critical_cells(1))

    return pd0, pd1

In [None]:
class ImagePersistence(torch.nn.Module):
    """Single 3x3 convolution followed by persistence diagrams of its output."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, kernel_size=3)

    def forward(self, X):
        """Return one ``persistence_diagram`` result per sample of the batch ``X``."""
        # channel 0 of each convolved sample is the 2-D image handed to gudhi
        return [persistence_diagram(sample[0]) for sample in self.conv(X)]

In [None]:
# run the (untrained) convolution + persistence module on the sampled batch
# and display the resulting list of diagrams
persistent_homology = ImagePersistence()
persistent_homology(images)

#### Model

In [None]:
class ConvDiagram(nn.Module):
    """Turn each entry of a batch into persistence-diagram rows.

    Every output row is ``[birth, death, j, i]`` where ``j`` is the position of
    the diagram in the tuple returned by ``diagram`` and ``i`` is the index of
    the entry along the first axis of the input.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, x):
        rows = []
        for sample_idx, sample in enumerate(x):
            per_dim = diagram(sample.flatten(), self.device)
            for hom_dim, points in enumerate(per_dim):
                # constant label columns, one row per diagram point
                labels = torch.Tensor([[hom_dim, sample_idx]] * points.shape[0]).to(self.device)
                rows.append(torch.cat([points, labels], dim=1))
        return torch.cat(rows)


class Transformer(torch.nn.Module):
    """Transformer encoder over a fixed-length sequence, returning class logits.

    Parameters
    ----------
    n_in : int
        Number of features per sequence element.
    n_hidden : int
        Embedding size (``d_model`` of the encoder layers).
    n_out : int
        Number of output classes.
    seq_size : int
        Sequence length the classifier expects; inputs must be padded or
        truncated to exactly this length.
    nhead, num_layers, dim_feedforward : int
        Passed to ``nn.TransformerEncoderLayer`` / ``nn.TransformerEncoder``.
    """

    def __init__(self, n_in, n_hidden, n_out, seq_size=1024, nhead=2, num_layers=2, dim_feedforward=16):
        super(Transformer, self).__init__()
        self.embeddings = nn.Linear(n_in, n_hidden)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=n_hidden, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer=self.encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(seq_size, n_out)

    def forward(self, X):
        """Return unnormalized class scores (logits) of shape ``(..., n_out)``.

        No softmax is applied: ``CrossEntropyLoss`` (used to train this model)
        applies log-softmax itself, so normalizing here would squash the
        gradients. ``argmax`` predictions are unaffected; call
        ``.softmax(dim=-1)`` on the result when probabilities are needed.
        """
        X = self.embeddings(X)
        X = self.transformer(X)
        # average over the hidden features, leaving one value per sequence position
        X = X.mean(dim=-1)
        X = self.classifier(X)
        return X


class TopologicalConvTransformer(nn.Module):
    """Convolution -> persistence diagram -> transformer classifier.

    Each image is processed on its own: one convolution with GELU and batch
    norm, a persistence diagram of the result, truncation / zero-padding of the
    diagram to ``max_sequence`` rows, and finally the ``Transformer`` head.
    """

    def __init__(self, n_in, n_conv, max_sequence, n_diag, n_hidden, n_out, nhead=2, num_layers=2, dim_feedforward=16, device='cuda'):
        super().__init__()

        self.max_sequence = max_sequence
        self.conv1 = nn.Conv2d(n_in, n_conv, 3)
        # conv2/conv3 and bn2/bn3 are created but not used by forward()
        self.conv2 = nn.Conv2d(n_conv, n_conv, 3)
        self.conv3 = nn.Conv2d(n_conv, n_conv, 3)
        self.bn1 = nn.BatchNorm2d(n_conv)
        self.bn2 = nn.BatchNorm2d(n_conv)
        self.bn3 = nn.BatchNorm2d(n_conv)
        self.diagram = ConvDiagram(device)
        self.transformer = Transformer(n_diag, n_hidden, n_out, max_sequence, nhead, num_layers, dim_feedforward)

    def forward(self, xs):
        """Return one row of transformer outputs per image in ``xs``."""
        outputs = []
        for sample in xs:
            # add a leading batch axis of size one and rescale the pixel values
            feats = sample.unsqueeze(0) / 256
            feats = self.bn1(F.gelu(self.conv1(feats)))

            # keep at most max_sequence diagram rows, then zero-pad up to it
            points = self.diagram(feats)[: self.max_sequence]
            n_missing = self.max_sequence - points.shape[0]
            points = F.pad(points, (0, 0, 0, n_missing), "constant", 0)

            outputs.append(self.transformer(points))
        return torch.stack(outputs, dim=0)

In [None]:
# constructor arguments shared by every TopologicalConvTransformer built below
kwargs = dict(
    n_in=1,
    n_conv=1,
    max_sequence=64,
    n_diag=4,
    n_hidden=32,
    n_out=10,
    nhead=2,
    num_layers=2,
    dim_feedforward=16,
    device="cpu",
)

In [None]:
# smoke test: build the model from `kwargs` and display its output for the
# first image of the sampled batch
model = TopologicalConvTransformer(**kwargs)
model(images)[0]

In [None]:
%%time
n_repeats = 1
n_epochs = 20
batch_size = 64
lr = 0.002

history = np.zeros((n_repeats, n_epochs, 3))
criterion = CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # randomness
    seed = repeat_idx
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # random state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # data init
    dataset_conv = MNIST("./data", train=False, download=True, transform=transform)
    dataloader_conv = DataLoader(dataset_conv, batch_size=batch_size, shuffle=True)
    
    # model init
    model = TopologicalConvTransformer(**kwargs)
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for batch in dataloader_conv:
            loss_batch = criterion(model(batch[0]), batch[1])
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for batch in dataloader_conv:
            y_hat = model(batch[0]).argmax(dim=1)
            correct += int((y_hat == batch[1]).sum())
        accuracy_train = correct / len(dataloader_conv.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")