In [25]:
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import geomstats.backend as gs
from geomstats.geometry.spd_matrices import (
    SPDMatrices,
    SPDMetricAffine,
    SPDMetricBuresWasserstein,
    SPDMetricLogEuclidean,
    SPDMetricEuclidean,
)
from geomstats.geometry.symmetric_matrices import SymmetricMatrices
from geomstats.geometry.matrices import Matrices
from geomstats.learning.mdm import RiemannianMinimumDistanceToMeanClassifier
from timeit import default_timer as timer
import torch.nn.functional as F
import torchvision


import torchvision.transforms as T
import torchvision.transforms.functional as Tf

import torch.nn as nn
import torch
import pandas as pd
import numpy as np


In [26]:
class CovarianceDescriptor(nn.Module):
    """Map a batch of images to one covariance matrix of pixel features each.

    For every interior pixel (the one-pixel border is dropped because the 3x3
    filters are applied without padding) these channels are stacked, in order:
    normalised row coordinate, normalised column coordinate, the image
    channel(s), the absolute responses of four 3x3 filters, the gradient
    magnitude and the gradient angle.  ``forward`` returns the covariance of
    those channels over all interior pixels plus a small multiple of the
    identity, i.e. a ``(batch, D, D)`` tensor (D = 9 for 1-channel input,
    D = 11 for 3-channel input).
    """

    def __init__(self):
        super().__init__()
        # Non-persistent buffers: they follow ``.to(device)`` / ``.double()``
        # like the rest of the module, while ``state_dict()`` stays empty as
        # it was when these were plain attributes.
        self.register_buffer(
            "sobel_x",
            torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]),
            persistent=False,
        )
        self.register_buffer(
            "laplacian_x",
            torch.tensor([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]]),
            persistent=False,
        )

    def derivative_features(self, img):
        """Return ``[abs filter responses, gradient magnitude, gradient angle]``.

        Parameters
        ----------
        img : torch.Tensor
            Batch of shape ``(batch, channels, H, W)``.  Three-channel input is
            converted to grayscale first; the convolution expects one channel.

        Returns
        -------
        list of torch.Tensor
            Tensors of shape ``(batch, 4, H-2, W-2)``, ``(batch, 1, H-2, W-2)``
            and ``(batch, 1, H-2, W-2)``.
        """
        if img.shape[1] == 3:
            img = T.Grayscale()(img)
        smoothed_img = Tf.gaussian_blur(img, kernel_size=(3, 3), sigma=(0.2, 0.2))
        # NOTE(review): the Laplacian kernel is symmetric, so ``laplacian_x.T``
        # is the same filter and channels 2 and 3 are identical.  That makes
        # the covariance rank-deficient up to the added noise term.  Kept as is
        # so the descriptor size and existing results do not change; per-axis
        # second-derivative kernels were probably intended -- confirm.
        filters = torch.stack(
            [self.sobel_x, self.sobel_x.T, self.laplacian_x, self.laplacian_x.T],
            dim=0,
        ).unsqueeze(1)
        # Follow the input so the convolution never mixes devices or dtypes.
        filters = filters.to(device=smoothed_img.device, dtype=smoothed_img.dtype)
        abs_derivatives = torch.abs(F.conv2d(smoothed_img, filters))
        first_response = abs_derivatives[:, 0:1, :, :]
        second_response = abs_derivatives[:, 1:2, :, :]
        norm_derivatives = torch.sqrt(first_response ** 2 + second_response ** 2)
        angle = torch.arctan2(first_response, second_response)
        return [abs_derivatives, norm_derivatives, angle]

    def other_features(self, img):
        """Return ``[row coordinate, column coordinate, image]`` on interior pixels.

        Coordinates are pixel indices divided by the image height (rows) and
        width (columns).  For square images this equals the previous behaviour;
        non-square images, which used to fail at the broadcast, now work.

        Parameters
        ----------
        img : torch.Tensor
            Batch of shape ``(batch, channels, H, W)``.

        Returns
        -------
        list of torch.Tensor
            Two ``(batch, 1, H-2, W-2)`` coordinate maps and the cropped image
            of shape ``(batch, channels, H-2, W-2)``.
        """
        batch, height, width = img.shape[0], img.shape[-2], img.shape[-1]
        coord_dtype = img.dtype if img.is_floating_point() else torch.float32
        # Plain broadcasting instead of torch.meshgrid: same values, created
        # directly on the input's device, and no ``indexing`` warning.
        row_index = torch.arange(height, dtype=coord_dtype, device=img.device)
        col_index = torch.arange(width, dtype=coord_dtype, device=img.device)
        grid_x = (row_index / height).view(1, 1, height, 1)
        grid_y = (col_index / width).view(1, 1, 1, width)
        broadcasted_grid_x = grid_x.expand(batch, 1, height, width)
        broadcasted_grid_y = grid_y.expand(batch, 1, height, width)
        return [
            broadcasted_grid_x[:, :, 1:-1, 1:-1],
            broadcasted_grid_y[:, :, 1:-1, 1:-1],
            img[:, :, 1:-1, 1:-1],
        ]

    def covar(self, features, noise=1e-6):
        """Return the per-sample covariance of ``features`` plus ``noise * I``.

        Parameters
        ----------
        features : torch.Tensor
            Shape ``(batch, D, N)``: D feature channels observed at N points.
        noise : float
            Multiple of the identity added to the result.

        Returns
        -------
        torch.Tensor
            Shape ``(batch, D, D)``; the covariance is normalised by ``N``.
        """
        assert len(features.shape) == 3
        _, D, N = features.shape
        centered = features - torch.sum(features, 2, keepdim=True) / N
        cov = torch.einsum("ijk,ilk->ijl", centered, centered) / (N)
        # Build the identity where the covariance lives; a default CPU/float32
        # ``eye`` would fail for inputs on another device.
        identity = torch.eye(D, dtype=cov.dtype, device=cov.device)
        return cov + noise * identity

    def forward(self, img):
        """Compute the covariance descriptor for a ``(batch, C, H, W)`` batch."""
        assert len(img.shape) == 4
        derivative_features = self.derivative_features(img)
        other_features = self.other_features(img)
        features = torch.cat(other_features + derivative_features, dim=1)
        vectorized = features.reshape(features.shape[0], features.shape[1], -1)
        return self.covar(vectorized)


# Shared descriptor instance; run_forall_datasets calls it on each image batch.
feature_extractor = CovarianceDescriptor()


In [27]:
def _images_to_unit_tensor(raw_images, layout):
    """Convert raw 0-255 images to a float ``(N, C, H, W)`` tensor in [0, 1].

    ``layout`` names the axis order of ``raw_images``: ``"NHW"`` (no channel
    axis, one is inserted), ``"NHWC"`` (channels last, moved to axis 1) or
    ``"NCHW"`` (used as is).
    """
    images = torch.as_tensor(raw_images)
    if layout == "NHW":
        images = images.unsqueeze(1)
    elif layout == "NHWC":
        images = images.permute(0, 3, 1, 2)
    return images.float() / 255


def get_data_processed(dataset):
    """Download a torchvision dataset and return it as plain tensors.

    Parameters
    ----------
    dataset : str
        One of ``"MNIST"``, ``"FashionMNIST"``, ``"KMNIST"``, ``"CIFAR10"``,
        ``"CIFAR100"`` or ``"SVHN"``.  Files are stored under ``data/``.

    Returns
    -------
    tuple
        ``(train_X, test_X, train_labels, test_labels)``; images are float
        tensors of shape ``(N, C, H, W)`` scaled to [0, 1].

    Raises
    ------
    ValueError
        If ``dataset`` is not a supported name.  Previously the function fell
        through and returned ``None``, which only failed later when the caller
        unpacked the result.
    """
    # name -> (raw image layout, attribute holding the labels,
    #          whether the constructor takes split="train"/"test" instead of train=bool)
    dataset_specs = {
        "MNIST": ("NHW", "targets", False),
        "FashionMNIST": ("NHW", "targets", False),
        "KMNIST": ("NHW", "targets", False),
        "CIFAR10": ("NHWC", "targets", False),
        "CIFAR100": ("NHWC", "targets", False),
        "SVHN": ("NCHW", "labels", True),
    }
    if dataset not in dataset_specs:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(dataset_specs)}"
        )
    layout, label_attr, uses_split_kwarg = dataset_specs[dataset]

    # Every supported name is also the torchvision.datasets class name.
    dataset_cls = getattr(torchvision.datasets, dataset)
    if uses_split_kwarg:
        train_data = dataset_cls("data/", split="train", download=True)
        test_data = dataset_cls("data/", split="test", download=True)
    else:
        train_data = dataset_cls("data/", train=True, download=True)
        test_data = dataset_cls("data/", train=False, download=True)

    train_X = _images_to_unit_tensor(train_data.data, layout)
    test_X = _images_to_unit_tensor(test_data.data, layout)
    train_labels = torch.as_tensor(getattr(train_data, label_attr))
    test_labels = torch.as_tensor(getattr(test_data, label_attr))

    print(train_X.shape)
    print(test_X.shape)
    print(train_labels.shape)
    print(test_labels.shape)
    return train_X, test_X, train_labels, test_labels

In [69]:
def run_mdm(data, name):
    """Fit and score a minimum-distance-to-mean classifier once per SPD metric.

    Parameters
    ----------
    data : tuple
        ``(train_features, test_features, train_labels, test_labels)``.  The
        last axis of ``train_features`` is used as the matrix size and the
        number of distinct values in ``train_labels`` as the class count.
    name : str
        Label printed in the banners and stored as the first element of both
        returned lists.

    Returns
    -------
    tuple of list
        ``([name, accuracy...], [name, seconds...])`` with one entry per
        metric, in the order the metrics are tried.  Each timing covers both
        fitting and scoring.
    """
    print("--------",name, "----------")
    train_features, test_features, train_labels, test_labels = data
    n_classes = torch.unique(train_labels).shape[0]
    n = train_features.shape[-1]
    print("n_classes, n", n_classes, n)

    candidate_metrics = (
        SPDMetricAffine(n),
        SPDMetricAffine(n, power_affine=0.5),
        SPDMetricAffine(n, power_affine=-0.5),
        SPDMetricLogEuclidean(n),
        SPDMetricEuclidean(n),
    )

    accuracies = []
    timings = []
    for spd_metric in candidate_metrics:
        tic = timer()
        classifier = RiemannianMinimumDistanceToMeanClassifier(spd_metric, n_classes)
        classifier.fit(train_features, train_labels)
        score = classifier.score(test_features, test_labels)
        duration = timer() - tic
        print(spd_metric.__class__.__name__, score, duration)
        accuracies.append(score)
        timings.append(duration)

    print("--------",name, " ----------")
    # Both result rows lead with the dataset label.
    return [name, *accuracies], [name, *timings]

In [70]:
def run_forall_datasets(datasets=None):
    """Benchmark the MDM classifier on several datasets and tabulate results.

    For each dataset the images are loaded with ``get_data_processed``, turned
    into covariance descriptors by ``feature_extractor`` and evaluated with
    ``run_mdm``.

    Parameters
    ----------
    datasets : iterable of str, optional
        Dataset names to run.  ``None`` (the default) runs the original full
        list; pass a subset for a quicker run.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(accuracy_table, time_table)``, one row per dataset.  The first
        column is the dataset name, the rest are one column per metric.
    """
    if datasets is None:
        datasets = ["MNIST", "CIFAR10", "CIFAR100", "SVHN", "FashionMNIST", "KMNIST"]
    # Column order must match the order in which run_mdm tries its metrics.
    columns = [
        "Dataset",
        "SPDMetricAffine",
        "SPDMetricAffine(0.5)",
        "SPDMetricAffine(-0.5)",
        "SPDMetricLogEuclidean",
        "SPDMetricEuclidean",
    ]
    accuracy_rows = []
    time_rows = []
    for dataset in datasets:
        train_X, test_X, train_labels, test_labels = get_data_processed(dataset)
        train_features = feature_extractor(train_X)
        test_features = feature_extractor(test_X)
        accuracy_row, time_row = run_mdm(
            (train_features, test_features, train_labels, test_labels), dataset
        )
        accuracy_rows.append(accuracy_row)
        time_rows.append(time_row)

    print(accuracy_rows)
    print(time_rows)
    accuracy_table = pd.DataFrame(accuracy_rows, columns=columns)
    time_table = pd.DataFrame(time_rows, columns=columns)
    return accuracy_table, time_table
    
        

In [71]:
# Run the benchmark on every dataset listed in run_forall_datasets and keep both tables.
accuracy_table, time_table = run_forall_datasets()


torch.Size([60000, 1, 28, 28])
torch.Size([10000, 1, 28, 28])
torch.Size([60000])
torch.Size([10000])
grid torch.Size([60000, 1, 28, 28])
torch.Size([60000, 9, 676])
grid torch.Size([10000, 1, 28, 28])
torch.Size([10000, 9, 676])
-------- MNIST ----------
n_classes, n 10 9
SPDMetricAffine 0.4945 7.696119062995422
SPDMetricAffine 0.4986 19.092085199001303
SPDMetricAffine 0.4979 17.3718318320025
SPDMetricLogEuclidean 0.4976 12.795387423000648
SPDMetricEuclidean 0.4162 0.8786930040005245
-------- MNIST  ----------
Files already downloaded and verified
Files already downloaded and verified
torch.Size([50000, 3, 32, 32])
torch.Size([10000, 3, 32, 32])
torch.Size([50000])
torch.Size([10000])
grid torch.Size([50000, 1, 32, 32])
torch.Size([50000, 11, 900])
grid torch.Size([10000, 1, 32, 32])
torch.Size([10000, 11, 900])
-------- CIFAR10 ----------
n_classes, n 10 11
SPDMetricAffine 0.3443 12.32726083900343
SPDMetricAffine 0.3327 21.91709678400366
SPDMetricAffine 0.3327 22.373411061998922
SPDM

In [72]:
# Display the accuracy table: one row per dataset, one column per metric.
accuracy_table

Unnamed: 0,Dataset,SPDMetricAffine,SPDMetricAffine(0.5),SPDMetricAffine(-0.5),SPDMetricLogEuclidean,SPDMetricEuclidean
0,MNIST,0.4945,0.4986,0.4979,0.4976,0.4162
1,CIFAR10,0.3443,0.3327,0.3327,0.3294,0.2493
2,CIFAR100,0.1292,0.1231,0.1232,0.1222,0.0662
3,SVHN,0.239705,0.233981,0.233943,0.231638,0.247119
4,FashionMNIST,0.5596,0.5436,0.5437,0.5389,0.3827
5,KMNIST,0.3515,0.3401,0.3401,0.3378,0.2254


In [73]:
# Display the timing table: values are the elapsed timer readings recorded by run_mdm.
time_table

Unnamed: 0,Dataset,SPDMetricAffine,SPDMetricAffine(0.5),SPDMetricAffine(-0.5),SPDMetricLogEuclidean,SPDMetricEuclidean
0,MNIST,7.696119,19.092085,17.371832,12.795387,0.878693
1,CIFAR10,12.327261,21.917097,22.373411,14.688317,0.967956
2,CIFAR100,27.107733,52.261168,50.824133,27.120047,0.94346
3,SVHN,26.991098,56.935063,57.498821,35.220235,2.187987
4,FashionMNIST,9.598451,21.20547,21.423664,14.725171,0.882422
5,KMNIST,10.478332,21.214482,19.140688,12.241405,0.844941
