In [None]:
import torch.nn.functional as F 
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as Tf

import torch.nn as nn
import torch

class CovarianceDescriptor(nn.Module):
    """Covariance descriptor of per-pixel image features.

    For every image in a batch a stack of feature maps is built and the
    covariance matrix of those features over all interior pixels is returned.
    The features are, in order:

    * normalised row coordinate and normalised column coordinate (2 channels),
    * the raw image channels (``C`` channels),
    * absolute responses of four 3x3 derivative filters (4 channels),
    * gradient magnitude and gradient angle (2 channels).

    The descriptor is therefore ``D x D`` with ``D = C + 8``: 9x9 for
    single-channel input and 11x11 for three-channel input.
    """

    def __init__(self):
        super().__init__()
        # 3x3 first-order (Sobel) and second-order (Laplacian) derivative kernels.
        self.sobel_x = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
        self.laplacian_x = torch.tensor([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]])

    def derivative_features(self, img):
        """Return derivative-based feature maps of ``img``.

        Args:
            img: tensor of shape (batch, channels, height, width) with 1 or 3
                channels; three-channel input is converted to grayscale first.

        Returns:
            A list ``[abs_derivatives, norm, angle]`` with 4, 1 and 1 channels.
            The filters are applied without padding, so each spatial dimension
            is 2 pixels smaller than the input.
        """
        if img.shape[1] == 3:
            img = T.Grayscale()(img)
        smoothed = Tf.gaussian_blur(img, kernel_size=(3, 3), sigma=(0.2, 0.2))

        # NOTE: the Laplacian kernel is symmetric, so its transpose is the same
        # kernel and channels 2 and 3 are identical. Both are kept so that the
        # descriptor size expected by downstream cells does not change; the
        # resulting covariance is only full rank thanks to the noise term
        # added in `covar`.
        kernels = torch.stack(
            [self.sobel_x, self.sobel_x.T, self.laplacian_x, self.laplacian_x.T], dim=0
        ).unsqueeze(1)
        # Follow the input's dtype/device so the module also works on GPU input.
        kernels = kernels.to(dtype=smoothed.dtype, device=smoothed.device)

        abs_derivatives = F.conv2d(smoothed, kernels).abs()
        first = abs_derivatives[:, 0:1]
        second = abs_derivatives[:, 1:2]
        norm_derivatives = torch.sqrt(first ** 2 + second ** 2)
        # Both arguments are non-negative, so the angle lies in [0, pi/2].
        angle = torch.arctan2(first, second)
        return [abs_derivatives, norm_derivatives, angle]

    def other_features(self, img):
        """Return coordinate and intensity feature maps of ``img``.

        Args:
            img: tensor of shape (batch, channels, height, width).

        Returns:
            A list ``[grid_x, grid_y, pixels]`` where ``grid_x`` holds the row
            index divided by the height, ``grid_y`` the column index divided by
            the width (one channel each, float32) and ``pixels`` is ``img``
            itself. All three are cropped by one pixel on every side to match
            the unpadded 3x3 filter output of `derivative_features`.
        """
        batch, _, height, width = img.shape
        # Height and width are handled separately so that non-square images work;
        # for square images this equals the previous "index / side length" grid.
        rows = torch.arange(height, device=img.device).float().view(1, 1, height, 1) / height
        cols = torch.arange(width, device=img.device).float().view(1, 1, 1, width) / width
        grid_x = rows.expand(batch, 1, height, width)
        grid_y = cols.expand(batch, 1, height, width)
        return [grid_x[:, :, 1:-1, 1:-1], grid_y[:, :, 1:-1, 1:-1], img[:, :, 1:-1, 1:-1]]

    def covar(self, features, noise=1e-6):
        """Batched covariance of ``features`` with diagonal regularisation.

        Args:
            features: tensor of shape (batch, D, N) holding N samples of D features.
            noise: value added to the diagonal of every covariance matrix.

        Returns:
            Tensor of shape (batch, D, D). The covariance is normalised by N
            (not N - 1).
        """
        assert features.dim() == 3, "features must have shape (batch, D, N)"
        _, D, N = features.shape
        centered = features - features.mean(dim=2, keepdim=True)
        cov = torch.einsum('ijk,ilk->ijl', centered, centered) / N
        # Build the identity on the same device/dtype as the data.
        identity = torch.eye(D, dtype=cov.dtype, device=cov.device)
        return cov + noise * identity

    def forward(self, img):
        """Compute one covariance descriptor per image.

        Args:
            img: tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, D, D) with ``D = channels + 8``.
        """
        assert img.dim() == 4, "img must have shape (batch, channels, height, width)"
        feature_maps = self.other_features(img) + self.derivative_features(img)
        features = torch.cat(feature_maps, dim=1)
        # Flatten the spatial dimensions: (batch, D, H', W') -> (batch, D, H' * W').
        vectorized = features.flatten(start_dim=2)
        return self.covar(vectorized)

# Single extractor instance reused by the feature-extraction cells below.
feature_extractor = CovarianceDescriptor()

In [None]:
# NOTE(review): this list is not referenced by any other cell visible here, and
# SVHN (loaded further down) is not among its entries; confirm it is still needed.
datasets = ["MNIST", "CIFAR10", "CIFAR100" ]

In [None]:
# Single-channel dataset for the 9x9 descriptors (n = 9) classified below.
# This cell uses the MNIST dataset API (`.data` as an (N, H, W) tensor and
# `.targets`); SVHN exposes numpy `data` / `labels` instead and is loaded in
# its own cell further down.
train_data = torchvision.datasets.MNIST("data/", train=True, download=True)
test_data = torchvision.datasets.MNIST("data/", train=False, download=True)

# Add the channel axis and scale the uint8 pixel values to [0, 1].
train_X = train_data.data.unsqueeze(1).float() / 255
test_X = test_data.data.unsqueeze(1).float() / 255
train_labels = train_data.targets
test_labels = test_data.targets

print(train_X.shape)
print(test_X.shape)

print(train_labels.shape)
print(test_labels.shape)

In [None]:
# Compute the covariance descriptors for both splits, then report their shape
# and value range.
train_features, test_features = map(feature_extractor, (train_X, test_X))
print(train_features.shape, train_features.max(), train_features.min(), sep="\n")
print(test_features.shape, test_features.max(), test_features.min(), sep="\n")

In [None]:
# Select the geomstats backend through the environment before geomstats is
# imported below. NOTE(review): the backend is presumably read at import time,
# so this assignment has to stay ahead of the geomstats imports — confirm.
import os 
os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import geomstats.backend as gs
from geomstats.geometry.spd_matrices import SPDMatrices, SPDMetricAffine, SPDMetricBuresWasserstein, SPDMetricLogEuclidean, SPDMetricEuclidean
from geomstats.geometry.symmetric_matrices import SymmetricMatrices
from geomstats.geometry.matrices import Matrices
from geomstats.learning.mdm import RiemannianMinimumDistanceToMeanClassifier

In [None]:
# Minimum-distance-to-mean classification of the 9x9 descriptors under several
# SPD metrics.
n_classes = 10
n = 9
metric_list = [
    SPDMetricAffine(n),
    # power_affine=1.0 is the default and would repeat the metric above; use
    # 0.5 so the sweep matches the one run on the 11x11 descriptors.
    SPDMetricAffine(n, power_affine=0.5),
    SPDMetricAffine(n, power_affine=-0.5),
    SPDMetricLogEuclidean(n),
    SPDMetricEuclidean(n),
]
for metric in metric_list:
    mdm = RiemannianMinimumDistanceToMeanClassifier(metric, n_classes)
    mdm.fit(train_features, train_labels)
    accuracy = mdm.score(test_features, test_labels)
    print(metric.__class__.__name__, accuracy)

In [None]:
# Download (if needed) and load both SVHN splits.
train_data, test_data = (
    torchvision.datasets.SVHN("data/", split=split_name, download=True)
    for split_name in ("train", "test")
)

# Convert the dataset arrays to tensors; pixel values are scaled by 1/255.
train_X, test_X = (
    torch.tensor(split.data).float() / 255 for split in (train_data, test_data)
)
train_labels, test_labels = (
    torch.tensor(split.labels) for split in (train_data, test_data)
)

print(train_data.labels)

# Shapes of the images first, then of the labels, one per line.
print(train_X.shape, test_X.shape, sep="\n")
print(train_labels.shape, test_labels.shape, sep="\n")

In [None]:
# Recompute the descriptors for the currently loaded splits and print, for each
# split, the descriptor shape followed by its largest and smallest entry.
train_features, test_features = map(feature_extractor, (train_X, test_X))
print(train_features.shape, train_features.max(), train_features.min(), sep="\n")
print(test_features.shape, test_features.max(), test_features.min(), sep="\n")

In [None]:
# Inspect the label tensor produced by the preceding data-loading cell.
print(train_labels)

In [None]:
from timeit import default_timer as timer

In [None]:
# Timed minimum-distance-to-mean classification of the 11x11 descriptors under
# several SPD metrics; the elapsed time covers both fitting and scoring.
n_classes = 10
n = 11
metric_list = [
    SPDMetricAffine(n),
    SPDMetricAffine(n, power_affine=0.5),
    SPDMetricAffine(n, power_affine=-0.5),
    SPDMetricLogEuclidean(n),
    SPDMetricEuclidean(n),
]
for metric in metric_list:
    start = timer()
    mdm = RiemannianMinimumDistanceToMeanClassifier(metric, n_classes)
    mdm.fit(train_features, train_labels)
    accuracy = mdm.score(test_features, test_labels)
    elapsed = timer() - start
    print(type(metric).__name__, accuracy, elapsed)