In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt


In [5]:
!pip install torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting torchvision
  Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl.metadata (6.6 kB)
Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m69.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torchvision
Successfully installed torchvision-0.17.0


In [8]:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# Directory the CIFAR-10 archive is downloaded to and read from.
DATA_ROOT = '/mnt_mount/labproject_data'

# Preprocessing applied to every image before it is fed to the embedding net.
preprocessing_steps = [
    # upscale to 299x299
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    # per-channel mean/std normalization (the values commonly used for
    # ImageNet-pretrained torchvision models)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Train and test splits of CIFAR-10; both share the same root and transform.
cifar10_train = CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
cifar10_test = CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /mnt_mount/labproject_data/cifar-10-python.tar.gz


100.0%


Extracting /mnt_mount/labproject_data/cifar-10-python.tar.gz to /mnt_mount/labproject_data
Files already downloaded and verified


In [10]:
# Both loaders use identical settings; shuffle stays off so the order of the
# extracted features follows the order of the underlying dataset.
loader_options = dict(batch_size=100, shuffle=False, num_workers=1)

dataloader_1 = torch.utils.data.DataLoader(cifar10_train, **loader_options)  # train split
dataloader_2 = torch.utils.data.DataLoader(cifar10_test, **loader_options)  # test split

In [11]:
from torchvision.models import inception_v3
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np


# get embedding net
def get_embedding_net():
    """Build an ImageNet-pretrained Inception-v3 used as a feature extractor.

    The final fully connected layer (``fc``) is replaced by an identity, so a
    forward pass returns the activations that would otherwise feed the
    classifier. The model is switched to eval mode and moved to CUDA when it
    is available, otherwise it stays on the CPU.

    Returns:
        The modified ``inception_v3`` module on the selected device.
    """
    # `pretrained=True` is deprecated since torchvision 0.13 and emits a
    # warning; naming the weights explicitly loads the same ImageNet checkpoint.
    model = inception_v3(weights="IMAGENET1K_V1")
    model.fc = torch.nn.Identity()  # replace the classifier with identity to get features
    model.eval()
    return model.to('cuda' if torch.cuda.is_available() else 'cpu')

# extract features
def extract_features(dataloader, model):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    Args:
        dataloader: Iterable yielding ``(data, label)`` pairs; the labels are
            ignored.
        model: Callable (typically an ``nn.Module``) applied to each ``data``
            batch. Gradients are disabled while it runs.

    Returns:
        A NumPy array holding the per-batch outputs concatenated along the
        first dimension, in iteration order.
    """
    # Send inputs to the device the model actually lives on. For callables
    # without parameters, keep the CUDA-if-available choice used previously.
    params = model.parameters() if hasattr(model, 'parameters') else ()
    first_param = next(iter(params), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    features = []
    with torch.no_grad():
        for data, _ in dataloader:
            # Move each output to the CPU immediately so the results of the
            # whole dataset do not accumulate in GPU memory.
            features.append(model(data.to(device)).cpu())
    return torch.cat(features).numpy()


# Build the feature extractor once; it is reused for both loaders below.
embedding_net = get_embedding_net()

# Embed each dataset: features1 comes from dataloader_1, features2 from dataloader_2.
features1 = extract_features(dataloader=dataloader_1, model=embedding_net)
features2 = extract_features(dataloader=dataloader_2, model=embedding_net)




In [12]:
# Display the shapes of both feature arrays as a sanity check.
tuple(feats.shape for feats in (features1, features2))

((50000, 2048), (10000, 2048))

In [16]:
from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance

# Only the first 10000 rows of features1 are compared against all of features2.
features1_subset = torch.from_numpy(features1)[:10000]
features2_tensor = torch.from_numpy(features2)

# Sliced Wasserstein distance between the two feature sets, using 1000 projections.
swd = sliced_wasserstein_distance(
    features1_subset,
    features2_tensor,
    num_projections=1000,
)