In [1]:
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import geomstats.backend as gs
from geomstats.geometry.spd_matrices import (
    SPDMatrices,
    SPDMetricAffine,
    SPDMetricBuresWasserstein,
    SPDMetricLogEuclidean,
    SPDMetricEuclidean,
)
from geomstats.geometry.symmetric_matrices import SymmetricMatrices
from geomstats.geometry.matrices import Matrices
from geomstats.learning.mdm import RiemannianMinimumDistanceToMeanClassifier
from timeit import default_timer as timer
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as Tf

import torch.nn as nn
import torch
import pandas as pd
import numpy as np


INFO: Using pytorch backend
INFO: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO: NumExpr defaulting to 8 threads.


In [2]:
class CovarianceDescriptor(nn.Module):
    """Region-covariance image descriptor.

    For every image it builds a stack of per-pixel features (normalized row /
    column coordinates, raw intensities, absolute first and second derivatives,
    gradient magnitude and orientation) and returns their covariance matrix,
    regularized with ``noise * I`` so it is strictly positive definite.

    Input: ``(B, C, H, W)`` float tensor with ``C`` in {1, 3}; 3-channel input is
    converted to grayscale for the derivative features only.
    Output: ``(B, D, D)`` tensor with ``D = C + 8`` (2 coordinates + C
    intensities + 4 absolute derivatives + magnitude + angle).
    """

    def __init__(self):
        super().__init__()
        # Registered as (non-persistent) buffers so the kernels follow the module
        # through .to(device) / .double() without changing the state_dict.
        self.register_buffer(
            "sobel_x",
            torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]),
            persistent=False,
        )
        # Second derivative along x; its transpose is the second derivative along y.
        # The kernel must NOT be symmetric: a symmetric kernel (e.g. the 5-point
        # Laplacian) equals its own transpose, which would make two feature
        # channels identical and the covariance matrix singular.
        self.register_buffer(
            "laplacian_x",
            torch.tensor([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]]),
            persistent=False,
        )

    def derivative_features(self, img):
        """Absolute derivatives, gradient magnitude and orientation.

        Returns a list of three tensors of spatial size ``(H-2, W-2)`` (unpadded
        3x3 convolution): the 4 channels ``|Ix|, |Iy|, |Ixx|, |Iyy|``, the
        gradient magnitude ``sqrt(Ix^2 + Iy^2)`` and ``atan2(|Ix|, |Iy|)``.
        """
        if img.shape[1] == 3:
            img = T.Grayscale()(img)
        smoothed_img = Tf.gaussian_blur(img, kernel_size=(3, 3), sigma=(0.2, 0.2))
        filters = (
            torch.stack(
                [self.sobel_x, self.sobel_x.T, self.laplacian_x, self.laplacian_x.T],
                dim=0,
            )
            .unsqueeze(1)
            .to(dtype=smoothed_img.dtype, device=smoothed_img.device)
        )
        abs_derivatives = torch.abs(F.conv2d(smoothed_img, filters))
        # Slicing with 0:1 / 1:2 keeps the channel axis.
        grad_x = abs_derivatives[:, 0:1]
        grad_y = abs_derivatives[:, 1:2]
        norm_derivatives = torch.sqrt(grad_x ** 2 + grad_y ** 2)
        angle = torch.atan2(grad_x, grad_y)
        return [abs_derivatives, norm_derivatives, angle]

    def other_features(self, img):
        """Normalized row / column coordinates and raw intensities.

        All three outputs are cropped by one pixel on every side so they align
        with the ``(H-2, W-2)`` output of :meth:`derivative_features`. Row indices
        are divided by ``H`` and column indices by ``W`` (non-square images are
        supported; for square images both divisors equal the width).
        """
        batch, _, height, width = img.shape
        rows, cols = torch.meshgrid(
            torch.arange(height, dtype=img.dtype, device=img.device),
            torch.arange(width, dtype=img.dtype, device=img.device),
            indexing="ij",
        )
        grid_rows = (rows / height).expand(batch, 1, height, width)
        grid_cols = (cols / width).expand(batch, 1, height, width)
        return [
            grid_rows[:, :, 1:-1, 1:-1],
            grid_cols[:, :, 1:-1, 1:-1],
            img[:, :, 1:-1, 1:-1],
        ]

    def covar(self, features, noise=1e-6):
        """Batched biased (1/N) covariance plus ``noise * I``.

        features: ``(B, D, N)`` tensor holding N samples of D features per item.
        Returns a ``(B, D, D)`` tensor with the same dtype/device as ``features``.
        """
        assert features.dim() == 3, (
            f"expected (B, D, N) features, got shape {tuple(features.shape)}"
        )
        _, dim, n_samples = features.shape
        centered = features - features.mean(dim=2, keepdim=True)
        cov = centered @ centered.transpose(1, 2) / n_samples
        # Build the regularizer on the features' device/dtype (a CPU eye would
        # fail to add to CUDA tensors).
        eye = torch.eye(dim, dtype=features.dtype, device=features.device)
        return cov + noise * eye

    def forward(self, img):
        """Map a ``(B, C, H, W)`` image batch to ``(B, D, D)`` covariance descriptors."""
        assert img.dim() == 4, f"expected (B, C, H, W) images, got shape {tuple(img.shape)}"
        features = torch.cat(
            self.other_features(img) + self.derivative_features(img), dim=1
        )
        # (B, D, H-2, W-2) -> (B, D, N): every pixel is one sample.
        vectorized = features.flatten(start_dim=2)
        return self.covar(vectorized)


feature_extractor = CovarianceDescriptor()


In [3]:
def _as_float_images(images, channels_last=False):
    """Convert uint8 images to a float32 ``(N, C, H, W)`` tensor scaled to [0, 1].

    images: array/tensor shaped ``(N, H, W)`` (a channel axis is added),
    ``(N, H, W, C)`` when ``channels_last`` is True, else ``(N, C, H, W)``.
    """
    images = torch.as_tensor(images)
    if images.dim() == 3:
        images = images.unsqueeze(1)
    elif channels_last:
        images = images.permute(0, 3, 1, 2)
    return images.float() / 255


def _stack_dataset(dataset):
    """Materialize an indexable dataset of ``(image_tensor, label)`` pairs into two tensors."""
    if len(dataset) == 0:
        raise ValueError("cannot stack an empty dataset")
    images, labels = zip(*(dataset[i] for i in range(len(dataset))))
    return torch.stack(images), torch.tensor(labels)


def get_data_processed(dataset, root="data/", image_size=64):
    """Download ``dataset`` and return it as in-memory tensors.

    Returns ``(train_X, test_X, train_labels, test_labels)``: image tensors are
    float32 ``(N, C, H, W)`` in [0, 1], labels are 1-D tensors.

    MNIST / FashionMNIST / CIFAR10 / CIFAR100 / SVHN keep their native resolution.
    Flowers102 / FGVCAircraft / Food101 hold variable-size photos, so they are
    resized to ``image_size x image_size`` RGB to be stackable; memory grows as
    ``N * 3 * image_size**2 * 4`` bytes, which is large for Food101.

    Raises ValueError for an unsupported dataset name.
    """
    grayscale_sets = ("MNIST", "FashionMNIST")
    cifar_sets = ("CIFAR10", "CIFAR100")
    resized_sets = ("Flowers102", "FGVCAircraft", "Food101")
    supported = grayscale_sets + cifar_sets + ("SVHN",) + resized_sets
    if dataset not in supported:
        raise ValueError(f"unsupported dataset {dataset!r}; expected one of {supported}")

    datasets_module = torchvision.datasets
    if dataset in grayscale_sets or dataset in cifar_sets:
        dataset_cls = getattr(datasets_module, dataset)
        train_data = dataset_cls(root, train=True, download=True)
        test_data = dataset_cls(root, train=False, download=True)
        # CIFAR stores (N, H, W, C) numpy arrays; MNIST-like sets store (N, H, W).
        channels_last = dataset in cifar_sets
        train_X = _as_float_images(train_data.data, channels_last)
        test_X = _as_float_images(test_data.data, channels_last)
        train_labels = torch.as_tensor(train_data.targets)
        test_labels = torch.as_tensor(test_data.targets)
    elif dataset == "SVHN":
        train_data = datasets_module.SVHN(root, split="train", download=True)
        test_data = datasets_module.SVHN(root, split="test", download=True)
        # SVHN already stores (N, C, H, W) arrays and names its targets `labels`.
        train_X = _as_float_images(train_data.data)
        test_X = _as_float_images(test_data.data)
        train_labels = torch.as_tensor(train_data.labels)
        test_labels = torch.as_tensor(test_data.labels)
    else:
        # These datasets load PIL images lazily and expose no `.data` / `.targets`.
        transform = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
        dataset_cls = getattr(datasets_module, dataset)
        # FGVCAircraft names its full training split "trainval"; the others use "train".
        train_split = "trainval" if dataset == "FGVCAircraft" else "train"
        train_X, train_labels = _stack_dataset(
            dataset_cls(root, split=train_split, download=True, transform=transform)
        )
        test_X, test_labels = _stack_dataset(
            dataset_cls(root, split="test", download=True, transform=transform)
        )

    for tensor in (train_X, test_X, train_labels, test_labels):
        print(tensor.shape)
    return train_X, test_X, train_labels, test_labels

In [5]:
def run_forall_datasets(
    datasets=("FGVCAircraft",), load_fn=None, extract_fn=None, evaluate_fn=None
):
    """Run the covariance-descriptor + classifier pipeline on each dataset.

    Parameters
    ----------
    datasets : iterable of str
        Dataset names understood by ``load_fn``; also used as table columns.
    load_fn : callable, optional
        ``name -> (train_X, test_X, train_labels, test_labels)``;
        defaults to ``get_data_processed``.
    extract_fn : callable, optional
        Maps an image batch to descriptors; defaults to the global
        ``feature_extractor``.
    evaluate_fn : callable, optional
        Called with one tuple ``(train_features, test_features, train_labels,
        test_labels)`` and returning ``(accuracy, metric)``; defaults to ``run_mdm``.

    Returns
    -------
    (accuracy_table, time_table)
        One-row DataFrames with one column per dataset, holding the accuracies
        and the second value returned by ``evaluate_fn`` (presumably a timing,
        given the name — confirm against ``run_mdm``). Empty frames when
        ``datasets`` is empty.
    """
    datasets = list(datasets)
    if not datasets:
        # np.column_stack rejects an empty list, so short-circuit here.
        return pd.DataFrame(), pd.DataFrame()

    load_fn = get_data_processed if load_fn is None else load_fn
    extract_fn = feature_extractor if extract_fn is None else extract_fn
    # NOTE(review): `run_mdm` is not defined in the cells shown here; it must
    # exist in the kernel before calling this without `evaluate_fn`.
    evaluate_fn = run_mdm if evaluate_fn is None else evaluate_fn

    accuracy_list, metric_list = [], []
    for name in datasets:
        train_X, test_X, train_labels, test_labels = load_fn(name)
        train_features = extract_fn(train_X)
        test_features = extract_fn(test_X)
        accuracy, metric = evaluate_fn(
            (train_features, test_features, train_labels, test_labels)
        )
        accuracy_list.append(accuracy)
        metric_list.append(metric)

    accuracy_table = pd.DataFrame(np.column_stack(accuracy_list), columns=datasets)
    time_table = pd.DataFrame(np.column_stack(metric_list), columns=datasets)
    return accuracy_table, time_table
    
        

In [106]:
# Evaluate the pipeline on the configured datasets and display the accuracy table.
# NOTE(review): relies on `run_mdm`, which is not defined in the cells shown here.
accuracy_table, time_table = run_forall_datasets()
accuracy_table

Downloading https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz to data/fgvc-aircraft-2013b.tar.gz


0.1%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

0.4%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

0.7%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

0.9%IOPub mes

In [72]:
# Load both Flowers102 splits for inspection (download=True fetches them if missing).
train_data, test_data = (
    torchvision.datasets.Flowers102("data/", split=split_name, download=True)
    for split_name in ("train", "test")
)

In [76]:
# Inspect the Flowers102 training split loaded above (its repr lists size, root and split).
train_data

Dataset Flowers102
    Number of datapoints: 1020
    Root location: data/
    split=train