# Subspace analysis

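This notebook computes the root-mean-square distance (RMSD) of the data distribution from the PCA, Patch-PCA, and downsampling subspaces. Each mean squared distance is normalized by the number of dimensions orthogonal to the subspace (e.g. `3*256*256*(1-1/64)` for a 32×32×3 subspace of 256×256 RGB images) before taking the square root. The numbers here correspond to Table 1 in the paper.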

In [1]:
import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import numpy as np
import os
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import tqdm
import seaborn
from sklearn.decomposition import PCA
sns = seaborn
def downsample(x):
    """Halve the spatial resolution of ``x`` by 2x2 average pooling (stride 2, no padding).

    Pools over the last two dimensions; with the default floor mode an odd
    trailing row/column is dropped.
    """
    return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
def upsample(x):
    """Nearest-neighbour 2x upsampling: every pixel is repeated into a 2x2 block.

    Leading dimensions of ``x`` are flattened into one batch dimension, so the
    result is always 4-D: ``(batch, C, 2*H, 2*W)``. Any channel count and
    non-square ``H x W`` inputs are supported.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well
    x = x.reshape(-1, *x.shape[-3:])
    B, C, H, W = x.shape
    # Insert unit axes after H and W, tile each to length 2, then fold back.
    return x.reshape(B, C, H, 1, W, 1).repeat(1, 1, 1, 2, 1, 2).reshape(B, C, 2 * H, 2 * W)
def subspace_squeeze(x, factor=2):
    """Scale down the part of ``x`` not captured by 2x downsampling.

    ``upsample(downsample(x))`` (the 2x2 block-average image) is kept, while
    the residual ``x - upsample(downsample(x))`` is shrunk by ``factor``: the
    result equals ``x - residual * (1 - 1/factor)``.
    """
    shrink = 1 - 1 / factor
    residual = x - upsample(downsample(x))
    return x - residual * shrink
import sys
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'


In [15]:
# Set to True to wrap each loader in a tqdm progress bar inside the *_distance functions.
use_tqdm = False
def repeat(func, x, n):
    """Return ``func`` applied ``n`` times to ``x``, i.e. ``func(func(...func(x)))``.

    ``n`` of zero (or less) returns ``x`` unchanged.
    """
    steps = range(n)
    result = x
    for _step in steps:
        result = func(result)
    return result
def downsampling_distance(loader, times):
    """Squared L2 distance of every image from the ``times``-fold downsampling subspace.

    Each batch is 2x2 average-pooled ``times`` times and then nearest-upsampled
    ``times`` times back to full resolution; the squared error between the
    image and this reconstruction is recorded per image.

    Args:
        loader: iterable of 4-D image batches, or of ``[images, labels]`` lists.
        times: number of 2x downsampling steps.

    Returns:
        list with one squared distance per image, in loader order.
    """
    batches = tqdm.tqdm(loader) if use_tqdm else loader
    per_image = []
    for batch in batches:
        # Labelled datasets are collated into [images, labels]; keep the images.
        images = batch[0] if type(batch) == list else batch
        coarse = repeat(downsample, images, times)
        restored = repeat(upsample, coarse, times)
        per_image += list(((restored - images) ** 2).sum((1, 2, 3)).numpy())
    return per_image

def mean_distance(loader):
    """Squared L2 distance of every image from the dataset's per-pixel mean image.

    The loader is iterated twice: once to accumulate the mean, once to
    measure distances.

    Args:
        loader: iterable of 4-D image batches, or of ``[images, labels]`` lists
            (as a DataLoader over a labelled dataset yields).

    Returns:
        list with one squared distance per image, in loader order; an empty
        list if the loader yields nothing.
    """
    if use_tqdm: loader = tqdm.tqdm(loader)
    # Pass 1: exact mean = (sum of all images) / (number of images). Averaging
    # per-batch means instead would over-weight a smaller final batch.
    total = 0
    n_images = 0
    for X in loader:
        if type(X) == list: X = X[0]
        total = total + X.sum(0)
        n_images += X.shape[0]
    if n_images == 0:
        return []
    mean = (total / n_images).unsqueeze(0)
    # Pass 2: squared distance of each image to the mean image.
    dists = []
    for X in loader:
        if type(X) == list: X = X[0]
        dist = ((X - mean)**2).sum((1,2,3)).numpy()
        dists.extend(list(dist))
    return dists

def pca_distance(loader, dims, max_n=None):
    """Squared L2 distance of every image from a ``dims``-dimensional PCA subspace.

    A ``PCA`` with ``dims`` components is fit on the first ``max_n`` images
    the loader yields (all of them when ``max_n`` is None). Every image is
    then flattened, projected with ``transform`` / ``inverse_transform``, and
    its squared reconstruction error is recorded.

    Args:
        loader: iterable of image batches, or of ``[images, labels]`` lists.
            Iterated twice (fit, then evaluation).
        dims: number of principal components to keep.
        max_n: optional cap on the number of images used for the fit.

    Returns:
        list with one squared reconstruction error per image, in loader order.
    """
    if use_tqdm: loader = tqdm.tqdm(loader)
    # Gather only the images the fit will use instead of materialising a
    # second full copy of the dataset and slicing it afterwards.
    Xs = []
    n_gathered = 0
    for X in loader:
        if max_n is not None and n_gathered >= max_n:
            break
        if type(X) == list: X = X[0]
        Xs.append(X)
        n_gathered += X.shape[0]
    Xs = torch.concat(Xs)
    # reshape (not view) also accepts non-contiguous batches
    Xs = Xs.reshape(Xs.shape[0], -1).float()
    # Xs[:None] keeps every row, so this also covers max_n=None
    pca = PCA(n_components=dims).fit(Xs[:max_n])
    dists = []
    for X in loader:
        if type(X) == list: X = X[0]
        X = X.reshape(X.shape[0], -1)
        X_fit = pca.inverse_transform(pca.transform(X))
        dist = ((X-X_fit)**2).sum(1).numpy()
        dists.extend(list(dist))
    return dists
    
    

In [16]:
def patch(X, s):
    """Move every non-overlapping ``s x s`` pixel patch into the channel axis.

    ``X`` of shape ``(B, C, H, W)`` becomes ``(B, C*s*s, H//s, W//s)``; output
    channel ``c*s*s + a*s + b`` at position ``(i, j)`` holds input pixel
    ``X[:, c, i*s + a, j*s + b]``. ``H`` and ``W`` must be divisible by ``s``.
    Any channel count and non-square images are supported.
    """
    B, C, H, W = X.shape
    X = X.reshape(B, C, H//s, s, W//s, s).permute(0, 1, 3, 5, 2, 4).reshape(B, C * s**2, H//s, W//s)
    return X
def unpatch(X, s):
    """Inverse of ``patch``: fold the channel axis back into ``s x s`` pixel patches.

    ``X`` of shape ``(B, C*s*s, h, w)`` becomes ``(B, C, h*s, w*s)``. All sizes
    are inferred from ``X`` itself; nothing is read from module globals.
    """
    C = X.shape[-3] // (s * s)
    h, w = X.shape[-2], X.shape[-1]
    return X.reshape(-1, C, s, s, h, w).permute(0, 1, 4, 2, 5, 3).reshape(-1, C, h * s, w * s)
def patch_pca_distance(loader, s, max_n=None, n_components=3):
    """Squared L2 distance of every image from the Patch-PCA subspace.

    Images are split into non-overlapping ``s x s`` patches via ``patch``. A
    single PCA with ``n_components`` components (3 by default) is fit on the
    patch vectors of the first ``max_n`` images (all images when None) and is
    shared by every patch position. Each image's squared reconstruction error
    is summed over all of its patches.

    Args:
        loader: iterable of image batches, or of ``[images, labels]`` lists.
            Iterated twice (fit, then evaluation).
        s: patch side length; the image side must be divisible by ``s``.
        max_n: optional cap on the number of images (not patches) used for the fit.
        n_components: number of PCA components kept per patch.

    Returns:
        list with one squared reconstruction error per image, in loader order.
    """
    if use_tqdm: loader = tqdm.tqdm(loader)
    # Gather only the images the fit will use instead of copying (and
    # patching) the whole dataset before slicing it.
    Xs = []
    n_gathered = 0
    for X in loader:
        if max_n is not None and n_gathered >= max_n:
            break
        if type(X) == list: X = X[0]
        Xs.append(X)
        n_gathered += X.shape[0]
    Xs = torch.concat(Xs)[:max_n].float()
    # (N, patch_dim, h, w) -> (N, h, w, patch_dim) so that flattening gives one row per patch
    P = patch(Xs, s).permute(0, 2, 3, 1)
    patch_dim = P.shape[-1]
    pca = PCA(n_components=n_components).fit(P.reshape(-1, patch_dim))

    dists = []
    for X in loader:
        if type(X) == list: X = X[0]
        X = patch(X, s).permute(0, 2, 3, 1)
        X_fit = pca.transform(X.reshape(-1, patch_dim))
        X_fit = pca.inverse_transform(X_fit).reshape(tuple(X.shape))
        dist = ((X-X_fit)**2).sum((1, 2, 3)).numpy()
        dists.extend(list(dist))
    return dists

# LSUN Church

In [5]:
# LSUN Church: a subset of the dataset stored as a NumPy array (available upon request).
church = np.load('lsun-church.npy')
# Move the last axis to position 1 (presumably (N, H, W, C) -> (N, C, H, W)) and divide
# by 255 (presumably uint8 pixels, giving values in [0, 1]).
church = torch.tensor(church.transpose(0, 3, 1, 2)).float() / 255
church_loader = DataLoader(church, shuffle=False, batch_size=1000)

In [6]:
# PCA subspaces with as many dimensions as a side x side RGB image. The mean squared
# distance is divided by the number of discarded dimensions, 3*256*256*(1 - 1/ratio).
for side in (32, 16, 8):
    ratio = (256 // side) ** 2
    msd = np.mean(pca_distance(church_loader, side*side*3, max_n=10000))/(3*256*256*(1-1/ratio))
    print(f"LSUN Church {side}x{side} PCA", np.sqrt(msd))

LSUN Church 32x32 PCA 0.08223481725677355
LSUN Church 16x16 PCA 0.10326715360850088
LSUN Church 8x8 PCA 0.12618160595111197


In [7]:
# Each downsampling step halves the side length and keeps 1/4 of the dimensions.
for times in range(1, 6):
    side = 256 >> times
    msd = np.mean(downsampling_distance(church_loader, times))/(3*256*256*(1-1/4**times))
    print(f"LSUN Church {side}x{side} downsampling", np.sqrt(msd))

# Distance to the mean image alone, normalised by the full dimensionality.
msd = np.mean(mean_distance(church_loader))/(3*256*256)
print("LSUN Church 0x0", np.sqrt(msd))

LSUN Church 128x128 downsampling 0.06969019929745306
LSUN Church 64x64 downsampling 0.0883778833712964
LSUN Church 32x32 downsampling 0.10887079373374633
LSUN Church 16x16 downsampling 0.13137195950968375
LSUN Church 8x8 downsampling 0.15799366823221628
LSUN Church 0x0 0.26245155038732393


In [8]:
# Patch-PCA with patch side p keeps 3 of 3*p*p dimensions per patch, i.e. a
# (256/p) x (256/p) x 3 subspace.
for patch_size in (2, 4, 8, 16, 32):
    side = 256 // patch_size
    msd = np.mean(patch_pca_distance(church_loader, patch_size, max_n=1000))/(3*256*256*(1-1/patch_size**2))
    print(f"LSUN Church {side}x{side} Patch-PCA", np.sqrt(msd))

LSUN Church 128x128 Patch-PCA 0.05752279336367003
LSUN Church 64x64 Patch-PCA 0.07905078133528938
LSUN Church 32x32 Patch-PCA 0.0993428842143585
LSUN Church 16x16 Patch-PCA 0.12086181618776415
LSUN Church 8x8 Patch-PCA 0.14647554526945197


# CelebA

In [9]:
# CelebA-HQ: a subset of the dataset stored as a NumPy array (available upon request).
# Unlike the LSUN file, no transpose is applied here, so this array is presumably
# already channel-first (N, 3, 256, 256) — TODO confirm.
celeba = np.load('celebA-HQ-256.npy')
celeba = torch.tensor(celeba).float() / 255
celeba_loader = DataLoader(celeba, shuffle=False, batch_size=1000)

In [10]:
# PCA subspaces with as many dimensions as a side x side RGB image. The mean squared
# distance is divided by the number of discarded dimensions, 3*256*256*(1 - 1/ratio).
for side in (32, 16, 8):
    ratio = (256 // side) ** 2
    msd = np.mean(pca_distance(celeba_loader, side*side*3, max_n=10000))/(3*256*256*(1-1/ratio))
    print(f"CelebA {side}x{side} PCA", np.sqrt(msd))

CelebA 32x32 PCA 0.04141311471471549
CelebA 16x16 PCA 0.058937831566438996
CelebA 8x8 PCA 0.08303656842782652


In [11]:
# Each downsampling step halves the side length and keeps 1/4 of the dimensions.
for times in range(1, 6):
    side = 256 >> times
    msd = np.mean(downsampling_distance(celeba_loader, times))/(3*256*256*(1-1/4**times))
    print(f"CelebA {side}x{side} downsampling", np.sqrt(msd))

# Distance to the mean image alone, normalised by the full dimensionality.
msd = np.mean(mean_distance(celeba_loader))/(3*256*256)
print("CelebA 0x0", np.sqrt(msd))

CelebA 128x128 downsampling 0.034147739617184675
CelebA 64x64 downsampling 0.05084107453389564
CelebA 32x32 downsampling 0.07347487958749131
CelebA 16x16 downsampling 0.1030002890867888
CelebA 8x8 downsampling 0.14072510083194403
CelebA 0x0 0.2616690877986674


In [12]:
# Patch-PCA with patch side p keeps 3 of 3*p*p dimensions per patch, i.e. a
# (256/p) x (256/p) x 3 subspace.
for patch_size in (2, 4, 8, 16, 32):
    side = 256 // patch_size
    msd = np.mean(patch_pca_distance(celeba_loader, patch_size, max_n=1000))/(3*256*256*(1-1/patch_size**2))
    print(f"CelebA {side}x{side} Patch-PCA", np.sqrt(msd))

CelebA 128x128 Patch-PCA 0.03414721132014191
CelebA 64x64 Patch-PCA 0.045853169749620705
CelebA 32x32 Patch-PCA 0.06341172993249068
CelebA 16x16 Patch-PCA 0.08697470398214206
CelebA 8x8 Patch-PCA 0.11673112512706059


# CIFAR-10

In [13]:
# CIFAR-10 training split; ToTensor yields float tensors in [0, 1].
# download=True fetches the archive on a fresh machine (and only checks the
# existing files otherwise), so Restart & Run All works without manual setup.
cifar = datasets.CIFAR10('cifar10', train=True, transform=transforms.ToTensor(), download=True)
cifar_loader = DataLoader(cifar, batch_size=1000, shuffle=False)

In [17]:
CIFAR_DIMS = 3*32*32  # 3072 dimensions per image

# Downsampling subspaces: each step halves the side length.
for times in (1, 2):
    side = 32 >> times
    msd = np.mean(downsampling_distance(cifar_loader, times))/(CIFAR_DIMS*(1-1/4**times))
    print(f"CIFAR-10 {side}x{side} downsampling", np.sqrt(msd))

# Distance to the mean image alone.
msd = np.mean(mean_distance(cifar_loader))/(CIFAR_DIMS)
print("CIFAR-10 0x0", np.sqrt(msd))

# PCA subspaces with as many dimensions as a side x side RGB image.
for side in (16, 8):
    msd = np.mean(pca_distance(cifar_loader, side*side*3))/(CIFAR_DIMS*(1-1/(32//side)**2))
    print(f"CIFAR-10 {side}x{side} PCA", np.sqrt(msd))

# Patch-PCA with patch side p spans a (32/p) x (32/p) x 3 subspace.
for patch_size in (2, 4):
    side = 32 // patch_size
    msd = np.mean(patch_pca_distance(cifar_loader, patch_size))/(CIFAR_DIMS*(1-1/patch_size**2))
    print(f"CIFAR-10 {side}x{side} Patch-PCA", np.sqrt(msd))

CIFAR-10 16x16 downsampling 0.07505646805106986
CIFAR-10 8x8 downsampling 0.11017048700915723
CIFAR-10 0x0 0.24895685839249357
CIFAR-10 16x16 PCA 0.02450205874356285
CIFAR-10 8x8 PCA 0.061117920745001775
CIFAR-10 16x16 Patch-PCA 0.0644708464728469
CIFAR-10 8x8 Patch-PCA 0.09349189440298912
