In [11]:
import torch.nn.functional as F 
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as Tf

import torch.nn as nn
import torch

class CovarianceDescriptor(nn.Module):

    def __init__(self):
        super().__init__()
        self.sobel_x = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
        self.laplacian_x = torch.tensor([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]])
        

    def derivative_features(self, img):
        if img.shape == 3:
            img = T.Grayscale()(img)
        smoothed_img = Tf.gaussian_blur(img, kernel_size=(3, 3), sigma=(0.2, 0.2))
        filters = torch.stack([self.sobel_x, self.sobel_x.T, self.laplacian_x, self.laplacian_x.T], axis=0).unsqueeze(1)
        abs_derivatives = torch.abs(F.conv2d(smoothed_img, filters))
        norm_derivatives = torch.sqrt(abs_derivatives[:,0, :, :]**2 + abs_derivatives[:, 1, : , :] **2).unsqueeze(1)
        angle = torch.arctan2(abs_derivatives[:, 0, :, :], abs_derivatives[:, 1, :, :]).unsqueeze(1)
        return  [abs_derivatives, norm_derivatives, angle]
    
    def other_features(self, img):
        grid_x, grid_y = torch.meshgrid(torch.arange(img.shape[-1]), torch.arange(img.shape[-1]))
        broadcasted_grid_x = torch.broadcast_to(grid_x, (img.shape[0], img.shape[1], 28, 28)).float()/28
        broadcasted_grid_y = torch.broadcast_to(grid_y, (img.shape[0], img.shape[1], 28, 28)).float()/28
        print(broadcasted_grid_x.shape)
        return [broadcasted_grid_x[:,: ,1:-1, 1:-1], broadcasted_grid_y[:,:, 1:-1, 1:-1], img[:,:, 1:-1, 1:-1]]

    def covar(self, features, noise=1e-6):
        print(features.shape)
        assert len(features.shape)== 3
        N = features.shape[2]
        centered = features - torch.sum(features, 2,keepdim=True)/N
        cov = torch.einsum('ijk,ilk->ijl',centered,centered)  / (N)
        return cov + noise* torch.eye(9)

    def forward(self, img):  
        assert len(img.shape) == 4  
        derivative_features = self.derivative_features(img)
        other_features = self.other_features(img)
        features = torch.cat( other_features +  derivative_features, axis=1)
        vectorized = features.reshape(features.shape[0], features.shape[1], -1)
        return self.covar(vectorized)

feature_extractor = CovarianceDescriptor()

In [24]:
train_data = torchvision.datasets.MNIST("data/", train=True, download=True)
test_data = torchvision.datasets.MNIST("data/", train=False, download=True)

train_X = train_data.data.unsqueeze(1).float()/255
test_X= test_data.data.unsqueeze(1).float()/255
train_labels = train_data.targets
test_labels = test_data.targets

print(train_X.shape)
print(test_X.shape)

print(train_labels.shape)
print(test_labels.shape)

torch.Size([60000, 1, 28, 28])
torch.Size([10000, 1, 28, 28])
torch.Size([60000])
torch.Size([10000])


In [26]:
train_features = feature_extractor(train_X)
test_features = feature_extractor(test_X)
print(train_features.shape)
print(train_features.max())
print(train_features.min())
print(test_features.shape)
print(test_features.max())
print(test_features.min())

torch.Size([60000, 1, 28, 28])
torch.Size([60000, 9, 676])
torch.Size([10000, 1, 28, 28])
torch.Size([10000, 9, 676])
torch.Size([60000, 9, 9])
tensor(2.9396)
tensor(-0.1525)
torch.Size([10000, 9, 9])
tensor(2.6335)
tensor(-0.1183)


In [27]:
import os 
os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import geomstats.backend as gs
from geomstats.geometry.spd_matrices import SPDMatrices, SPDMetricAffine, SPDMetricBuresWasserstein, SPDMetricLogEuclidean, SPDMetricEuclidean
from geomstats.geometry.symmetric_matrices import SymmetricMatrices
from geomstats.geometry.matrices import Matrices
from geomstats.learning.mdm import RiemannianMinimumDistanceToMeanClassifier

In [32]:
n_classes = 10
n = 9
metric_list = [SPDMetricAffine(n), SPDMetricAffine(n, power_affine=0.5), SPDMetricAffine(n, power_affine=-0.5), SPDMetricLogEuclidean(n), SPDMetricEuclidean(n)]
for metric in metric_list:
    mdm = RiemannianMinimumDistanceToMeanClassifier(metric, n_classes)
    mdm.fit(train_features, train_labels)
    accuracy = mdm.score(test_features, test_labels)
    print(metric.__class__.__name__, accuracy)

SPDMetricAffine 0.4946
SPDMetricAffine 0.4988
SPDMetricAffine 0.4982
SPDMetricLogEuclidean 0.4978
SPDMetricEuclidean 0.4162
